1use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
/// A single reversible schema change.
///
/// Implementors must be `Send + Sync`. `MigrationRunner` sorts migrations by
/// [`Migration::name`] and stores that name in the `migrations` table to
/// remember what has been applied.
pub trait Migration: Send + Sync {
    /// Static identifier of this migration. It is used as the sort key and as
    /// the value written to `migrations.name`, so it must be unique across
    /// all migrations handed to a runner.
    fn name(&self) -> &'static str;
    /// Records the forward change on `schema`; the runner executes the
    /// statements collected in `schema.statements` afterwards.
    fn up(&self, schema: &mut Schema);
    /// Records the reverse change on `schema`; executed by the runner when
    /// the migration is rolled back.
    fn down(&self, schema: &mut Schema);
}
16
// Declares `MigrationRegistration` as an `inventory` collection so that values
// submitted with `inventory::submit!` can later be iterated with
// `inventory::iter`.
inventory::collect!(MigrationRegistration);

/// Registry entry describing how to construct one migration.
pub struct MigrationRegistration {
    /// Function pointer that builds a boxed instance of the registered migration.
    pub builder: fn() -> Box<dyn Migration>,
}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24 inventory::iter::<MigrationRegistration>
25 .into_iter()
26 .map(|r| (r.builder)())
27 .collect()
28}
29
30fn check_unique_names(migrations: &[Box<dyn Migration>]) {
36 use std::collections::HashSet;
37 let mut seen: HashSet<&'static str> = HashSet::with_capacity(migrations.len());
38 let mut dups: Vec<&'static str> = Vec::new();
39 for m in migrations {
40 if !seen.insert(m.name()) {
41 dups.push(m.name());
42 }
43 }
44 if !dups.is_empty() {
45 panic!(
46 "duplicate Migration::name() values: {dups:?}. \
47 A `name()` collision lets one migration silently shadow another. \
48 Check that each migration file's `fn name(&self) -> &'static str` matches its filename stem."
49 );
50 }
51}
52
/// Declares a unit struct that implements `Migration` and registers it with
/// `inventory` so `collected()` can find it.
///
/// Arguments, in order: the struct identifier, an expression producing the
/// migration name, `up = <expr>` and `down = <expr>` (a trailing comma is
/// accepted). Both `up` and `down` must be usable as `fn(&mut Schema)`.
///
/// ```ignore
/// migration!(
///     CreateThingsTable,
///     "2026_01_01_000003_create_things_table",
///     up = |s| {
///         s.create("things", |t| {
///             t.id();
///         });
///     },
///     down = |s| {
///         s.drop_if_exists("things");
///     },
/// );
/// ```
#[macro_export]
macro_rules! migration {
    (
        $struct_name:ident,
        $name:expr,
        up = $up:expr,
        down = $down:expr $(,)?
    ) => {
        pub struct $struct_name;

        impl $crate::migration::Migration for $struct_name {
            fn name(&self) -> &'static str {
                $name
            }
            fn up(&self, schema: &mut $crate::schema::Schema) {
                // Binding to a `fn` pointer gives a closure literal a concrete
                // parameter type; only non-capturing closures and fn items coerce.
                let f: fn(&mut $crate::schema::Schema) = $up;
                f(schema);
            }
            fn down(&self, schema: &mut $crate::schema::Schema) {
                // Same coercion as in `up`.
                let f: fn(&mut $crate::schema::Schema) = $down;
                f(schema);
            }
        }

        // Register a builder for the new struct in the `MigrationRegistration`
        // collection.
        $crate::inventory::submit! {
            $crate::migration::MigrationRegistration {
                builder: || -> ::std::boxed::Box<dyn $crate::migration::Migration> {
                    ::std::boxed::Box::new($struct_name)
                },
            }
        }
    };
}
116
/// Applies, rolls back and reports on a set of migrations against one
/// database pool.
pub struct MigrationRunner {
    /// Connection pool every statement is executed on; its variant selects
    /// the SQL dialect.
    pool: Pool,
    /// Migrations known to this runner, kept sorted by `name()`.
    migrations: Vec<Box<dyn Migration>>,
}
121
122impl MigrationRunner {
123 pub fn new(pool: Pool) -> Self {
124 let mut migrations = collected();
125 check_unique_names(&migrations);
126 migrations.sort_by_key(|m| m.name().to_string());
127 Self { pool, migrations }
128 }
129
130 pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
131 check_unique_names(&migrations);
132 migrations.sort_by_key(|m| m.name().to_string());
133 Self { pool, migrations }
134 }
135
    /// Returns the database driver reported by the underlying pool.
    fn driver(&self) -> Driver {
        self.pool.driver()
    }
139
    /// Returns the dialect-specific `CREATE TABLE IF NOT EXISTS` statement for
    /// the bookkeeping table `migrations`.
    ///
    /// All three variants define the same columns: an auto-incrementing `id`,
    /// a unique `name`, the `batch` number and an `applied_at` timestamp that
    /// defaults to the current time.
    fn migrations_table_ddl(&self) -> &'static str {
        match self.driver() {
            Driver::Postgres => {
                "CREATE TABLE IF NOT EXISTS migrations (
                    id BIGSERIAL PRIMARY KEY,
                    name TEXT NOT NULL UNIQUE,
                    batch INTEGER NOT NULL,
                    applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                )"
            }
            Driver::MySql => {
                "CREATE TABLE IF NOT EXISTS migrations (
                    id BIGINT AUTO_INCREMENT PRIMARY KEY,
                    name VARCHAR(255) NOT NULL UNIQUE,
                    batch INT NOT NULL,
                    applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
                )"
            }
            Driver::Sqlite => {
                "CREATE TABLE IF NOT EXISTS migrations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL UNIQUE,
                    batch INTEGER NOT NULL,
                    applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
                )"
            }
        }
    }
170
171 fn fresh_ddl(&self) -> Vec<&'static str> {
172 match self.driver() {
173 Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
174 Driver::MySql => vec![
175 "",
178 ],
179 Driver::Sqlite => vec![
180 "PRAGMA writable_schema = 1",
181 "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
182 "PRAGMA writable_schema = 0",
183 "VACUUM",
184 ],
185 }
186 }
187
188 async fn exec(&self, sql: &str) -> Result<(), Error> {
191 if sql.is_empty() {
192 return Ok(());
193 }
194 match &self.pool {
195 Pool::Postgres(p) => {
196 sqlx::query(sql).execute(p).await?;
197 }
198 Pool::MySql(p) => {
199 sqlx::query(sql).execute(p).await?;
200 }
201 Pool::Sqlite(p) => {
202 sqlx::query(sql).execute(p).await?;
203 }
204 }
205 Ok(())
206 }
207
208 async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
209 Ok(match &self.pool {
210 Pool::Postgres(p) => {
211 sqlx::query_as::<_, (String, i32)>(
212 "SELECT name, batch FROM migrations ORDER BY batch, id",
213 )
214 .fetch_all(p)
215 .await?
216 }
217 Pool::MySql(p) => {
218 sqlx::query_as::<_, (String, i32)>(
219 "SELECT name, batch FROM migrations ORDER BY batch, id",
220 )
221 .fetch_all(p)
222 .await?
223 }
224 Pool::Sqlite(p) => {
225 sqlx::query_as::<_, (String, i32)>(
226 "SELECT name, batch FROM migrations ORDER BY batch, id",
227 )
228 .fetch_all(p)
229 .await?
230 }
231 })
232 }
233
234 async fn max_batch(&self) -> Result<Option<i32>, Error> {
235 Ok(match &self.pool {
236 Pool::Postgres(p) => {
237 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
238 .fetch_one(p)
239 .await?
240 .0
241 }
242 Pool::MySql(p) => {
243 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
244 .fetch_one(p)
245 .await?
246 .0
247 }
248 Pool::Sqlite(p) => {
249 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
250 .fetch_one(p)
251 .await?
252 .0
253 }
254 })
255 }
256
257 async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
258 let rows: Vec<(String,)> = match &self.pool {
259 Pool::Postgres(p) => {
260 sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
261 .bind(batch)
262 .fetch_all(p)
263 .await?
264 }
265 Pool::MySql(p) => {
266 sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
267 .bind(batch)
268 .fetch_all(p)
269 .await?
270 }
271 Pool::Sqlite(p) => {
272 sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
273 .bind(batch)
274 .fetch_all(p)
275 .await?
276 }
277 };
278 Ok(rows.into_iter().map(|(n,)| n).collect())
279 }
280
281 async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
282 match &self.pool {
283 Pool::Postgres(p) => {
284 sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
285 .bind(name)
286 .bind(batch)
287 .execute(p)
288 .await?;
289 }
290 Pool::MySql(p) => {
291 sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
292 .bind(name)
293 .bind(batch)
294 .execute(p)
295 .await?;
296 }
297 Pool::Sqlite(p) => {
298 sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
299 .bind(name)
300 .bind(batch)
301 .execute(p)
302 .await?;
303 }
304 }
305 Ok(())
306 }
307
308 async fn delete_applied(&self, name: &str) -> Result<(), Error> {
309 match &self.pool {
310 Pool::Postgres(p) => {
311 sqlx::query("DELETE FROM migrations WHERE name = $1")
312 .bind(name)
313 .execute(p)
314 .await?;
315 }
316 Pool::MySql(p) => {
317 sqlx::query("DELETE FROM migrations WHERE name = ?")
318 .bind(name)
319 .execute(p)
320 .await?;
321 }
322 Pool::Sqlite(p) => {
323 sqlx::query("DELETE FROM migrations WHERE name = ?1")
324 .bind(name)
325 .execute(p)
326 .await?;
327 }
328 }
329 Ok(())
330 }
331
332 async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
333 for s in stmts {
334 self.exec(s).await?;
335 }
336 Ok(())
337 }
338
339 pub async fn ensure_table(&self) -> Result<(), Error> {
342 let ddl = self.migrations_table_ddl();
343 self.exec(ddl).await
344 }
345
346 pub async fn applied(&self) -> Result<Vec<String>, Error> {
347 Ok(self
348 .applied_rows()
349 .await?
350 .into_iter()
351 .map(|(n, _)| n)
352 .collect())
353 }
354
355 pub async fn next_batch(&self) -> Result<i32, Error> {
356 Ok(self.max_batch().await?.unwrap_or(0) + 1)
357 }
358
359 pub async fn run_up(&self) -> Result<Vec<String>, Error> {
360 self.ensure_table().await?;
361 let already = self.applied().await?;
362 let batch = self.next_batch().await?;
363 let mut known_tables: std::collections::HashSet<String> =
366 self.list_user_tables().await?.into_iter().collect();
367 let mut applied = Vec::new();
368 for m in &self.migrations {
369 if already.iter().any(|a| a == m.name()) {
370 continue;
371 }
372 let mut schema = Schema::for_driver(self.driver());
373 m.up(&mut schema);
374 check_fk_ordering(m.name(), &schema.statements, &mut known_tables)?;
375 self.exec_many(&schema.statements).await?;
376 self.record_applied(m.name(), batch).await?;
377 applied.push(m.name().to_string());
378 tracing::info!(name = m.name(), "migration applied");
379 }
380 Ok(applied)
381 }
382
383 async fn list_user_tables(&self) -> Result<Vec<String>, Error> {
386 Ok(match &self.pool {
387 Pool::Postgres(p) => sqlx::query_as::<_, (String,)>(
388 "SELECT tablename FROM pg_tables WHERE schemaname = 'public'",
389 )
390 .fetch_all(p)
391 .await?
392 .into_iter()
393 .map(|(t,)| t)
394 .collect(),
395 Pool::MySql(p) => sqlx::query_as::<_, (String,)>(
396 "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
397 )
398 .fetch_all(p)
399 .await?
400 .into_iter()
401 .map(|(t,)| t)
402 .collect(),
403 Pool::Sqlite(p) => sqlx::query_as::<_, (String,)>(
404 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
405 )
406 .fetch_all(p)
407 .await?
408 .into_iter()
409 .map(|(t,)| t)
410 .collect(),
411 })
412 }
413
414 pub async fn rollback(&self) -> Result<Vec<String>, Error> {
415 self.ensure_table().await?;
416 let Some(batch) = self.max_batch().await? else {
417 return Ok(Vec::new());
418 };
419 let names = self.names_in_batch(batch).await?;
420 let mut rolled = Vec::new();
421 for name in names {
422 let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
423 tracing::warn!(name, "migration row in DB but not registered; skipping");
424 continue;
425 };
426 let mut schema = Schema::for_driver(self.driver());
427 m.down(&mut schema);
428 self.exec_many(&schema.statements).await?;
429 self.delete_applied(&name).await?;
430 rolled.push(name);
431 }
432 Ok(rolled)
433 }
434
435 pub async fn fresh(&self) -> Result<(), Error> {
436 self.wipe().await?;
437 self.run_up().await?;
438 Ok(())
439 }
440
441 pub async fn wipe(&self) -> Result<(), Error> {
448 match self.driver() {
449 Driver::Postgres => {
450 for s in self.fresh_ddl() {
451 self.exec(s).await?;
452 }
453 }
454 Driver::MySql => {
455 self.drop_all_mysql_tables().await?;
456 }
457 Driver::Sqlite => {
458 self.drop_all_sqlite_tables().await?;
459 }
460 }
461 Ok(())
462 }
463
464 async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
465 let Pool::MySql(p) = &self.pool else {
466 return Ok(());
467 };
468 let tables: Vec<(String,)> = sqlx::query_as(
469 "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
470 )
471 .fetch_all(p)
472 .await?;
473 sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
474 for (t,) in tables {
475 sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
476 .execute(p)
477 .await?;
478 }
479 sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
480 Ok(())
481 }
482
483 async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
484 let Pool::Sqlite(p) = &self.pool else {
485 return Ok(());
486 };
487 let tables: Vec<(String,)> = sqlx::query_as(
488 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
489 )
490 .fetch_all(p)
491 .await?;
492 for (t,) in tables {
493 sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
494 .execute(p)
495 .await?;
496 }
497 Ok(())
498 }
499
500 pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
501 self.ensure_table().await?;
502 let rows = self.applied_rows().await?;
503 let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
504
505 let mut out = Vec::new();
506 for m in &self.migrations {
507 let name = m.name().to_string();
508 let batch = applied_map.get(&name).copied();
509 out.push(MigrationStatus {
510 name,
511 applied: batch.is_some(),
512 batch,
513 });
514 }
515 for (db_name, batch) in &applied_map {
516 if !self.migrations.iter().any(|m| m.name() == db_name) {
517 out.push(MigrationStatus {
518 name: db_name.clone(),
519 applied: true,
520 batch: Some(*batch),
521 });
522 }
523 }
524 Ok(out)
525 }
526
527 pub async fn reset(&self) -> Result<Vec<String>, Error> {
528 self.ensure_table().await?;
529 let mut rolled_total = Vec::new();
530 loop {
531 let rolled = self.rollback().await?;
532 if rolled.is_empty() {
533 break;
534 }
535 rolled_total.extend(rolled);
536 }
537 Ok(rolled_total)
538 }
539
540 pub async fn refresh(&self) -> Result<Vec<String>, Error> {
541 self.reset().await?;
542 self.run_up().await
543 }
544
545 pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
546 self.ensure_table().await?;
547 let already = self.applied().await?;
548 let mut applied = Vec::new();
549 for m in &self.migrations {
550 if already.iter().any(|a| a == m.name()) {
551 continue;
552 }
553 let batch = self.next_batch().await?;
554 let mut schema = Schema::for_driver(self.driver());
555 m.up(&mut schema);
556 self.exec_many(&schema.statements).await?;
557 self.record_applied(m.name(), batch).await?;
558 applied.push(m.name().to_string());
559 tracing::info!(name = m.name(), batch, "migration applied (stepped)");
560 }
561 Ok(applied)
562 }
563
564 pub async fn pretend(&self) -> Result<Vec<String>, Error> {
565 self.ensure_table().await?;
566 let already = self.applied().await?;
567 let mut lines = Vec::new();
568 for m in &self.migrations {
569 if already.iter().any(|a| a == m.name()) {
570 continue;
571 }
572 lines.push(format!("-- migration: {}", m.name()));
573 let mut schema = Schema::for_driver(self.driver());
574 m.up(&mut schema);
575 for stmt in &schema.statements {
576 lines.push(format!("{stmt};"));
577 }
578 lines.push(String::new());
579 }
580 Ok(lines)
581 }
582
    /// Creates the `migrations` bookkeeping table if it is missing; identical
    /// to calling `ensure_table`.
    pub async fn install(&self) -> Result<(), Error> {
        self.ensure_table().await
    }
586
    /// Number of migrations held by this runner, applied or not.
    pub fn count(&self) -> usize {
        self.migrations.len()
    }
590}
591
/// One line of a migration status report.
///
/// Derives `PartialEq`/`Eq` so reports can be compared directly, e.g. in
/// tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Migration name, either from a registered migration or from a row in
    /// the `migrations` table.
    pub name: String,
    /// `true` when the `migrations` table contains a row for `name`.
    pub applied: bool,
    /// Batch the migration was applied in; `None` when it is still pending.
    pub batch: Option<i32>,
}
599
600fn check_fk_ordering(
611 migration_name: &str,
612 statements: &[String],
613 known_tables: &mut std::collections::HashSet<String>,
614) -> Result<(), Error> {
615 for stmt in statements {
616 let Some(table) = parse_create_table_name(stmt) else {
617 continue;
618 };
619 for ref_table in parse_fk_references(stmt) {
620 if ref_table == table {
622 continue;
623 }
624 if !known_tables.contains(&ref_table) {
625 return Err(Error::Internal(format!(
626 "migration `{migration_name}` creates table `{table}` with a \
627 foreign key referencing `{ref_table}`, but `{ref_table}` \
628 hasn't been created yet.\n\n\
629 Migrations apply in alphabetical-by-filename order — bump \
630 the filename timestamp of the migration that creates \
631 `{ref_table}` so it sorts BEFORE `{migration_name}`."
632 )));
633 }
634 }
635 known_tables.insert(table);
636 }
637 Ok(())
638}
639
640fn parse_create_table_name(stmt: &str) -> Option<String> {
643 let trimmed = stmt.trim_start();
644 let upper = trimmed.to_ascii_uppercase();
645 let prefix_len = if upper.starts_with("CREATE TABLE IF NOT EXISTS ") {
646 "CREATE TABLE IF NOT EXISTS ".len()
647 } else if upper.starts_with("CREATE TABLE ") {
648 "CREATE TABLE ".len()
649 } else {
650 return None;
651 };
652 let rest = &trimmed[prefix_len..];
653 Some(parse_identifier(rest)?.0)
654}
655
656fn parse_fk_references(stmt: &str) -> Vec<String> {
658 let mut refs = Vec::new();
659 let upper = stmt.to_ascii_uppercase();
660 let mut cursor = 0;
661 while let Some(idx) = upper[cursor..].find("REFERENCES ") {
662 let abs = cursor + idx + "REFERENCES ".len();
663 let rest = &stmt[abs..];
664 if let Some((name, consumed)) = parse_identifier(rest) {
665 refs.push(name);
666 cursor = abs + consumed;
667 } else {
668 break;
669 }
670 }
671 refs
672}
673
/// Parses one SQL identifier from the start of `s`.
///
/// A leading `"` or `` ` `` starts a quoted identifier that runs to the next
/// occurrence of the same quote character. Otherwise the identifier runs to
/// the first space, tab, newline, carriage return, `(`, `,` or `)`, or to the
/// end of the input.
///
/// Returns the identifier (without quotes) and the number of bytes consumed
/// (including quotes). Returns `None` for empty input, an empty identifier or
/// an unterminated quoted identifier.
fn parse_identifier(s: &str) -> Option<(String, usize)> {
    let first = s.as_bytes().first().copied()?;
    let (name, consumed) = match first {
        quote @ (b'"' | b'`') => {
            let inner = &s[1..];
            let close = inner.bytes().position(|b| b == quote)?;
            // Opening quote + body + closing quote.
            (&inner[..close], close + 2)
        }
        _ => {
            let end = s
                .bytes()
                .position(|b| matches!(b, b' ' | b'\t' | b'\n' | b'\r' | b'(' | b',' | b')'))
                .unwrap_or(s.len());
            (&s[..end], end)
        }
    };
    if name.is_empty() {
        return None;
    }
    Some((name.to_string(), consumed))
}
703
#[cfg(test)]
mod macro_tests {
    use super::*;
    use crate::schema::Schema;

    // Declares a migration through the macro so the test below can exercise
    // the generated `Migration` impl.
    crate::migration!(
        TestCreateThingsTable,
        "2026_01_01_000003_create_things_table",
        up = |schema| {
            schema.create("things", |table| {
                table.id();
                table.string("name").not_null();
            });
        },
        down = |schema| {
            schema.drop_if_exists("things");
        },
    );

    /// The macro-generated type reports its name and emits DDL in both
    /// directions.
    #[test]
    fn closure_migration_macro_expands_into_a_working_migration() {
        let migration = TestCreateThingsTable;
        assert_eq!(migration.name(), "2026_01_01_000003_create_things_table");

        let mut forward = Schema::for_driver(Driver::Sqlite);
        migration.up(&mut forward);
        assert!(
            !forward.statements.is_empty(),
            "up() should emit at least one DDL statement"
        );

        let mut backward = Schema::for_driver(Driver::Sqlite);
        migration.down(&mut backward);
        assert!(
            !backward.statements.is_empty(),
            "down() should emit at least one DDL statement"
        );
    }

    /// Minimal migration that only carries a name; `up`/`down` do nothing.
    struct NamedMigration(&'static str);

    impl Migration for NamedMigration {
        fn name(&self) -> &'static str {
            self.0
        }
        fn up(&self, _: &mut Schema) {}
        fn down(&self, _: &mut Schema) {}
    }

    /// Boxes one `NamedMigration` per supplied name.
    fn named(names: &[&'static str]) -> Vec<Box<dyn Migration>> {
        names
            .iter()
            .map(|name| Box::new(NamedMigration(name)) as Box<dyn Migration>)
            .collect()
    }

    #[test]
    fn check_unique_names_accepts_unique() {
        let migs = named(&[
            "2026_01_01_000001_a",
            "2026_01_01_000002_b",
            "2026_01_01_000003_c",
        ]);
        check_unique_names(&migs);
    }

    #[test]
    #[should_panic(expected = "duplicate Migration::name() values")]
    fn check_unique_names_panics_on_collision() {
        let migs = named(&["2026_01_01_000001_a", "2026_01_01_000001_a"]);
        check_unique_names(&migs);
    }
}