Skip to main content

lift_migration/runner/
migrator.rs

1use std::{collections::HashMap, path::PathBuf};
2
3use quex::{self, Driver, FromRow, Pool, Row};
4
5use super::pool::MigrationPool;
6use crate::{
7    AlterColumnBuilder, BlueprintExecutor, ColumnBuilder, MigrationContext, MigrationError,
8    MigrationReport, ResetReport,
9    files::{
10        AlterAction, AlterEnumAction, Field, FieldType, MigrationFile, MigrationOp, Nullability,
11        Reference, ReferenceAction, TableItem, default_migration_dir, load_migrations,
12    },
13};
14
/// Action requested from the migration command line.
///
/// Each variant is dispatched by [`run_cli_command`] to the matching
/// [`Migrator`] method.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CliCommand {
    /// Dispatched to [`Migrator::run`].
    Run,
    /// Dispatched to [`Migrator::refresh`].
    Refresh,
    /// Dispatched to [`Migrator::rollback_last_batch`].
    RollbackLastBatch,
    /// Dispatched to [`Migrator::rollback_steps`] with the given step count.
    RollbackSteps(usize),
    /// Dispatched to [`Migrator::reset`].
    Reset,
    /// Dispatched to [`Migrator::validate`].
    Validate,
}
24
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct CliOptions {
27    pub database_url: String,
28    pub command: CliCommand,
29    pub migration_dir: Option<PathBuf>,
30}
31
32pub async fn run_cli_command(options: &CliOptions) -> Result<MigrationReport, MigrationError> {
33    let migrator = Migrator::connect(&options.database_url).await?.path(
34        options
35            .migration_dir
36            .clone()
37            .unwrap_or_else(default_migration_dir),
38    );
39    match options.command {
40        CliCommand::Run => migrator.run().await,
41        CliCommand::Refresh => migrator.refresh().await,
42        CliCommand::RollbackLastBatch => migrator.rollback_last_batch().await,
43        CliCommand::RollbackSteps(steps) => migrator.rollback_steps(steps).await,
44        CliCommand::Reset => {
45            let report = migrator.reset().await?;
46            Ok(MigrationReport {
47                batch: None,
48                applied: Vec::new(),
49                rolled_back: report.rolled_back,
50            })
51        }
52        CliCommand::Validate => {
53            migrator.validate()?;
54            Ok(MigrationReport::default())
55        }
56    }
57}
58
59struct BatchValue {
60    batch: Option<i64>,
61}
62
63impl FromRow for BatchValue {
64    fn from_row(row: &Row) -> quex::Result<Self> {
65        Ok(Self {
66            batch: row.get("batch")?,
67        })
68    }
69}
70
71struct ExistsValue {
72    value: i64,
73}
74
75impl FromRow for ExistsValue {
76    fn from_row(row: &Row) -> quex::Result<Self> {
77        Ok(Self {
78            value: row.get("value")?,
79        })
80    }
81}
82
83struct AppliedMigrationRow {
84    id: i64,
85    checksum: String,
86    status: String,
87}
88
89impl FromRow for AppliedMigrationRow {
90    fn from_row(row: &Row) -> quex::Result<Self> {
91        Ok(Self {
92            id: row.get("id")?,
93            checksum: row.get("checksum")?,
94            status: row.get("status")?,
95        })
96    }
97}
98
99struct AppliedMigrationId {
100    id: i64,
101}
102
103impl FromRow for AppliedMigrationId {
104    fn from_row(row: &Row) -> quex::Result<Self> {
105        Ok(Self { id: row.get("id")? })
106    }
107}
108
109pub struct Migrator {
110    pub(crate) pool: MigrationPool,
111    migration_dir: PathBuf,
112}
113
114impl Migrator {
115    pub fn new(pool: impl Into<Self>) -> Self {
116        pool.into()
117    }
118
119    pub fn path(mut self, path: impl Into<PathBuf>) -> Self {
120        self.migration_dir = path.into();
121        self
122    }
123
124    pub async fn connect(database_url: &str) -> Result<Self, MigrationError> {
125        let pool = Pool::connect(database_url)?.max_size(1).build().await?;
126        let migrator = match pool.driver() {
127            Driver::Sqlite => {
128                #[cfg(feature = "sqlite")]
129                {
130                    Ok::<Self, MigrationError>(Self {
131                        pool: MigrationPool::Sqlite(pool),
132                        migration_dir: default_migration_dir(),
133                    })
134                }
135                #[cfg(not(feature = "sqlite"))]
136                {
137                    Err(MigrationError::BackendNotEnabled("sqlite"))
138                }
139            }
140            Driver::Pgsql => {
141                #[cfg(feature = "postgres")]
142                {
143                    Ok::<Self, MigrationError>(Self {
144                        pool: MigrationPool::Postgres(pool),
145                        migration_dir: default_migration_dir(),
146                    })
147                }
148                #[cfg(not(feature = "postgres"))]
149                {
150                    Err(MigrationError::BackendNotEnabled("postgres"))
151                }
152            }
153            Driver::Mysql => {
154                #[cfg(feature = "mariadb")]
155                {
156                    Ok::<Self, MigrationError>(Self {
157                        pool: MigrationPool::Mariadb(pool),
158                        migration_dir: default_migration_dir(),
159                    })
160                }
161                #[cfg(not(feature = "mariadb"))]
162                {
163                    Err(MigrationError::BackendNotEnabled("mariadb"))
164                }
165            }
166        }?;
167        migration_trace!(
168            backend = migrator.pool.backend_name(),
169            "connecting migrator"
170        );
171        Ok(migrator)
172    }
173
174    pub fn validate(&self) -> Result<(), MigrationError> {
175        let _ = self.load_migrations()?;
176        Ok(())
177    }
178
179    pub async fn reset(&self) -> Result<ResetReport, MigrationError> {
180        let mut rolled_back = Vec::new();
181        let mut batches = Vec::new();
182
183        loop {
184            let batch_report = self.rollback_last_batch().await?;
185            if batch_report.rolled_back.is_empty() {
186                break;
187            }
188
189            rolled_back.extend(batch_report.rolled_back);
190            if let Some(batch) = batch_report.batch {
191                batches.push(batch);
192            }
193        }
194
195        Ok(ResetReport {
196            rolled_back,
197            batches,
198        })
199    }
200
201    pub async fn refresh(&self) -> Result<MigrationReport, MigrationError> {
202        let reset_report = self.reset().await?;
203        let up_report = self.run().await?;
204
205        Ok(MigrationReport {
206            batch: up_report.batch,
207            applied: up_report.applied,
208            rolled_back: reset_report.rolled_back,
209        })
210    }
211
212    pub async fn run(&self) -> Result<MigrationReport, MigrationError> {
213        migration_trace!(backend = self.pool.backend_name(), "starting migration run");
214        self.ensure_table().await?;
215
216        let migrations = self.load_migrations()?;
217        let mut applied = self.applied_migrations().await?;
218        let mut retryable_failed_ids = Vec::new();
219
220        for migration in &migrations {
221            if let Some(row) = applied.get(&migration.id) {
222                if row.status == "failed" {
223                    if self.pool.supports_clean_failed_retry() {
224                        retryable_failed_ids.push(migration.id);
225                        continue;
226                    }
227                    return Err(MigrationError::FailedMigration { id: migration.id });
228                }
229                if row.checksum != migration.checksum {
230                    return Err(MigrationError::ChecksumMismatch { id: migration.id });
231                }
232            }
233        }
234
235        for id in retryable_failed_ids {
236            self.delete_applied(id).await?;
237            applied.remove(&id);
238        }
239
240        let pending: Vec<_> = migrations
241            .into_iter()
242            .filter(|migration| !applied.contains_key(&migration.id))
243            .collect();
244
245        if pending.is_empty() {
246            return Ok(MigrationReport::default());
247        }
248
249        let batch = self.next_batch().await?;
250        let mut report = MigrationReport {
251            batch: Some(batch),
252            applied: Vec::with_capacity(pending.len()),
253            rolled_back: Vec::new(),
254        };
255
256        for migration in pending {
257            self.mark_running(&migration, batch).await?;
258            match self.apply(&migration, &migration.up).await {
259                Ok(()) => {
260                    self.mark_applied(&migration.id).await?;
261                    report.applied.push(migration.id);
262                }
263                Err(error) => {
264                    self.mark_failed(migration.id).await?;
265                    return Err(error);
266                }
267            }
268        }
269
270        Ok(report)
271    }
272
273    pub async fn rollback_last_batch(&self) -> Result<MigrationReport, MigrationError> {
274        self.ensure_table().await?;
275
276        let Some(batch) = self.last_batch().await? else {
277            return Ok(MigrationReport::default());
278        };
279
280        let ids = self.ids_for_batch(batch).await?;
281        self.rollback_ids(ids, Some(batch)).await
282    }
283
284    pub async fn rollback_steps(&self, steps: usize) -> Result<MigrationReport, MigrationError> {
285        self.ensure_table().await?;
286        if steps == 0 {
287            return Ok(MigrationReport::default());
288        }
289
290        let ids = self.latest_ids(steps).await?;
291        self.rollback_ids(ids, None).await
292    }
293
294    fn load_migrations(&self) -> Result<Vec<MigrationFile>, MigrationError> {
295        load_migrations(&self.migration_dir)
296    }
297
298    async fn rollback_ids(
299        &self,
300        ids: Vec<u64>,
301        batch: Option<u64>,
302    ) -> Result<MigrationReport, MigrationError> {
303        if ids.is_empty() {
304            return Ok(MigrationReport::default());
305        }
306
307        let migrations = self
308            .load_migrations()?
309            .into_iter()
310            .map(|migration| (migration.id, migration))
311            .collect::<HashMap<_, _>>();
312
313        let mut report = MigrationReport {
314            batch,
315            applied: Vec::new(),
316            rolled_back: Vec::with_capacity(ids.len()),
317        };
318
319        for id in ids {
320            let migration = migrations
321                .get(&id)
322                .ok_or(MigrationError::MissingMigration(id))?;
323            self.apply(migration, &migration.down).await?;
324            self.delete_applied(id).await?;
325            report.rolled_back.push(id);
326        }
327
328        Ok(report)
329    }
330
331    async fn apply(
332        &self,
333        migration: &MigrationFile,
334        ops: &[MigrationOp],
335    ) -> Result<(), MigrationError> {
336        let mut tx = self.pool.pool().begin().await?;
337        let result = {
338            let mut ctx = self.new_transaction_context(&mut tx);
339            apply_operations(&mut ctx, &migration.path, ops).await
340        };
341        match result {
342            Ok(()) => tx.commit().await.map_err(Into::into),
343            Err(error) => {
344                let _ = tx.rollback().await;
345                Err(error)
346            }
347        }
348    }
349
350    fn new_transaction_context<'a>(
351        &self,
352        tx: &'a mut quex::PoolTransaction,
353    ) -> MigrationContext<'a> {
354        match &self.pool {
355            #[cfg(feature = "sqlite")]
356            MigrationPool::Sqlite(_) => {
357                MigrationContext::Sqlite(crate::SqliteMigrationContext::from_transaction(tx))
358            }
359            #[cfg(feature = "postgres")]
360            MigrationPool::Postgres(_) => {
361                MigrationContext::Postgres(crate::PostgresMigrationContext::from_transaction(tx))
362            }
363            #[cfg(feature = "mariadb")]
364            MigrationPool::Mariadb(_) => {
365                MigrationContext::Mariadb(crate::MariadbMigrationContext::from_transaction(tx))
366            }
367            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
368            MigrationPool::Disabled => MigrationContext::Disabled(std::marker::PhantomData),
369        }
370    }
371
372    async fn ensure_table(&self) -> Result<(), MigrationError> {
373        let exists_sql = match &self.pool {
374            #[cfg(feature = "sqlite")]
375            MigrationPool::Sqlite(_) => {
376                "select 1 as value from sqlite_master where type = 'table' and name = 'migrations' limit 1"
377            }
378            #[cfg(feature = "postgres")]
379            MigrationPool::Postgres(_) => {
380                "select 1 as value from pg_catalog.pg_class c join pg_catalog.pg_namespace n on n.oid = c.relnamespace where c.relkind = 'r' and c.relname = 'migrations' and n.nspname = current_schema() limit 1"
381            }
382            #[cfg(feature = "mariadb")]
383            MigrationPool::Mariadb(_) => {
384                "select 1 as value from information_schema.tables where table_schema = database() and table_name = 'migrations' limit 1"
385            }
386            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
387            MigrationPool::Disabled => return Err(MigrationError::BackendNotEnabled("no backend")),
388        };
389        let exists = quex::query(exists_sql)
390            .optional::<ExistsValue>(self.pool.pool())
391            .await?
392            .map(|row| row.value != 0)
393            .unwrap_or(false);
394        if exists {
395            return Ok(());
396        }
397
398        let sql = match &self.pool {
399            #[cfg(feature = "sqlite")]
400            MigrationPool::Sqlite(_) => {
401                "create table if not exists migrations (id integer primary key not null, name text not null, checksum text not null, batch integer not null, status text not null, started_at text not null default current_timestamp, finished_at text null)"
402            }
403            #[cfg(feature = "postgres")]
404            MigrationPool::Postgres(_) => {
405                "create table if not exists migrations (id bigint primary key not null, name text not null, checksum text not null, batch bigint not null, status text not null, started_at timestamptz not null default current_timestamp, finished_at timestamptz null)"
406            }
407            #[cfg(feature = "mariadb")]
408            MigrationPool::Mariadb(_) => {
409                "create table if not exists migrations (id bigint primary key not null, name varchar(255) not null, checksum varchar(255) not null, batch bigint not null, status varchar(32) not null, started_at timestamp not null default current_timestamp, finished_at timestamp null)"
410            }
411            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
412            MigrationPool::Disabled => return Err(MigrationError::BackendNotEnabled("no backend")),
413        };
414        quex::query(sql).execute(self.pool.pool()).await?;
415        Ok(())
416    }
417
418    async fn applied_migrations(
419        &self,
420    ) -> Result<HashMap<u64, AppliedMigrationRow>, MigrationError> {
421        let rows = quex::query("select id, checksum, status from migrations order by id")
422            .all::<AppliedMigrationRow>(self.pool.pool())
423            .await?;
424        Ok(rows
425            .into_iter()
426            .map(|row| (row.id as u64, row))
427            .collect::<HashMap<_, _>>())
428    }
429
430    async fn next_batch(&self) -> Result<u64, MigrationError> {
431        Ok(self.max_batch().await?.unwrap_or(0) + 1)
432    }
433
434    async fn last_batch(&self) -> Result<Option<u64>, MigrationError> {
435        self.max_batch().await
436    }
437
438    async fn max_batch(&self) -> Result<Option<u64>, MigrationError> {
439        let row =
440            quex::query("select max(batch) as batch from migrations where status = 'applied'")
441                .one::<BatchValue>(self.pool.pool())
442                .await?;
443        Ok(row.batch.map(|value| value as u64))
444    }
445
446    async fn ids_for_batch(&self, batch: u64) -> Result<Vec<u64>, MigrationError> {
447        let ids = quex::query(
448            "select id from migrations where batch = ? and status = 'applied' order by id desc",
449        )
450        .bind(batch as i64)
451        .all::<AppliedMigrationId>(self.pool.pool())
452        .await?;
453        Ok(ids.into_iter().map(|value| value.id as u64).collect())
454    }
455
456    async fn latest_ids(&self, steps: usize) -> Result<Vec<u64>, MigrationError> {
457        let ids = quex::query("select id from migrations where status = 'applied' order by batch desc, id desc limit ?")
458            .bind(steps as i64)
459            .all::<AppliedMigrationId>(self.pool.pool())
460            .await?;
461        Ok(ids.into_iter().map(|value| value.id as u64).collect())
462    }
463
464    async fn mark_running(
465        &self,
466        migration: &MigrationFile,
467        batch: u64,
468    ) -> Result<(), MigrationError> {
469        quex::query("delete from migrations where id = ?")
470            .bind(migration.id as i64)
471            .execute(self.pool.pool())
472            .await?;
473        quex::query(
474            "insert into migrations(id, name, checksum, batch, status) values(?, ?, ?, ?, ?)",
475        )
476        .bind(migration.id as i64)
477        .bind(&migration.name)
478        .bind(&migration.checksum)
479        .bind(batch as i64)
480        .bind("running")
481        .execute(self.pool.pool())
482        .await?;
483        Ok(())
484    }
485
486    async fn mark_applied(&self, id: &u64) -> Result<(), MigrationError> {
487        quex::query(
488            "update migrations set status = ?, finished_at = current_timestamp where id = ?",
489        )
490        .bind("applied")
491        .bind(*id as i64)
492        .execute(self.pool.pool())
493        .await?;
494        Ok(())
495    }
496
497    async fn mark_failed(&self, id: u64) -> Result<(), MigrationError> {
498        quex::query(
499            "update migrations set status = ?, finished_at = current_timestamp where id = ?",
500        )
501        .bind("failed")
502        .bind(id as i64)
503        .execute(self.pool.pool())
504        .await?;
505        Ok(())
506    }
507
508    async fn delete_applied(&self, id: u64) -> Result<(), MigrationError> {
509        quex::query("delete from migrations where id = ?")
510            .bind(id as i64)
511            .execute(self.pool.pool())
512            .await?;
513        Ok(())
514    }
515}
516
/// Column attributes shared by the builder types, so [`apply_field_attrs`] can
/// be written once and work with each of them.
trait ColumnAttrs: Sized {
    /// Marks the column as nullable.
    fn nullable(self) -> Self;
    /// Marks the column as unique.
    fn unique(self) -> Self;
    /// Sets the column default to the raw expression `value`.
    fn default_raw(self, value: &str) -> Self;
}
522
523impl<'a> ColumnAttrs for ColumnBuilder<'a> {
524    fn nullable(self) -> Self {
525        Self::nullable(self)
526    }
527
528    fn unique(self) -> Self {
529        Self::unique(self)
530    }
531
532    fn default_raw(self, value: &str) -> Self {
533        Self::default_raw(self, value)
534    }
535}
536
537impl<'a> ColumnAttrs for AlterColumnBuilder<'a> {
538    fn nullable(self) -> Self {
539        Self::nullable(self)
540    }
541
542    fn unique(self) -> Self {
543        Self::unique(self)
544    }
545
546    fn default_raw(self, value: &str) -> Self {
547        Self::default_raw(self, value)
548    }
549}
550
551impl<'a> ColumnAttrs for crate::ForeignKeyBuilder<'a> {
552    fn nullable(self) -> Self {
553        Self::nullable(self)
554    }
555
556    fn unique(self) -> Self {
557        Self::unique(self)
558    }
559
560    fn default_raw(self, value: &str) -> Self {
561        Self::default_raw(self, value)
562    }
563}
564
565fn apply_field_attrs<B>(builder: B, field: &Field) -> B
566where
567    B: ColumnAttrs,
568{
569    let mut builder = builder;
570    if matches!(field.nullable, Nullability::Nullable) {
571        builder = builder.nullable();
572    }
573    if field.unique {
574        builder = builder.unique();
575    }
576    if let Some(default) = &field.default {
577        builder = builder.default_raw(default);
578    }
579    builder
580}
581
582async fn apply_operations(
583    ctx: &mut MigrationContext<'_>,
584    path: &std::path::Path,
585    ops: &[MigrationOp],
586) -> Result<(), MigrationError> {
587    for op in ops {
588        match op {
589            MigrationOp::Sql { sql } => {
590                ctx.execute_raw(sql).await?;
591            }
592            MigrationOp::Backfill { sql } => {
593                ctx.execute_raw(sql).await?;
594            }
595            MigrationOp::CreateTable { name, items } => {
596                let items = resolve_table_items(ctx, path, items).await?;
597                ctx.create(name, |table| {
598                    for item in &items {
599                        match item {
600                            TableItem::Column(field) => apply_create_field(table, field),
601                            TableItem::Index {
602                                name: index_name,
603                                columns,
604                            } => table.index(
605                                index_name
606                                    .as_deref()
607                                    .unwrap_or(&default_index_name(name, columns)),
608                                columns.clone(),
609                            ),
610                            TableItem::Unique(columns) => table.unique(columns.clone()),
611                            TableItem::ConstraintUnique { name, columns } => {
612                                table.unique_named(name, columns.clone())
613                            }
614                            TableItem::Primary(columns) => table.primary(columns.clone()),
615                            TableItem::Timestamps => table.timestamps(),
616                            TableItem::SoftDeletes => {
617                                table.timestamp("deleted_at").nullable();
618                            }
619                        }
620                    }
621                })
622                .await?;
623            }
624            MigrationOp::AlterTable { name, actions } => {
625                let actions = resolve_alter_actions(ctx, path, actions).await?;
626                for action in &actions {
627                    if let AlterAction::DropIndex(index_name) = action {
628                        let index = crate::IndexBlueprint::named(index_name);
629                        ctx.execute_raw_blueprint(&index.drop_sql(ctx.dialect()))
630                            .await?;
631                    }
632                }
633                ctx.alter_table(name, |table| {
634                    for action in &actions {
635                        match action {
636                            AlterAction::AddColumn(field) if field.reference.is_none() => {
637                                apply_alter_field(table, field)
638                            }
639                            AlterAction::DropColumn(name) => table.drop_column(name),
640                            AlterAction::RenameColumn { .. } => {}
641                            AlterAction::AddIndex { .. } => {}
642                            AlterAction::DropIndex(_) | AlterAction::AddColumn(_) => {}
643                        }
644                    }
645                })
646                .await?;
647                for action in &actions {
648                    match action {
649                        AlterAction::RenameColumn { from, to } => {
650                            let sql = format!(
651                                "alter table {} rename column {} to {};",
652                                ctx.dialect().quote_ident(name),
653                                ctx.dialect().quote_ident(from),
654                                ctx.dialect().quote_ident(to)
655                            );
656                            ctx.execute_raw_blueprint(&sql).await?;
657                        }
658                        AlterAction::AddIndex {
659                            name: index_name,
660                            columns,
661                        } => {
662                            let index = crate::IndexBlueprint::new(
663                                index_name
664                                    .as_deref()
665                                    .unwrap_or(&default_index_name(name, columns)),
666                                name,
667                                columns.clone(),
668                            );
669                            ctx.execute_raw_blueprint(&index.create_sql(ctx.dialect()))
670                                .await?;
671                        }
672                        AlterAction::AddColumn(field) if field.reference.is_some() => {
673                            add_relation_column(ctx, name, field).await?;
674                        }
675                        AlterAction::DropIndex(_) => {}
676                        AlterAction::AddColumn(_) | AlterAction::DropColumn(_) => {}
677                    }
678                }
679            }
680            MigrationOp::DropTable { name } => {
681                ctx.drop(name).await?;
682            }
683            MigrationOp::RenameTable { from, to } => {
684                let sql = format!(
685                    "alter table {} rename to {};",
686                    ctx.dialect().quote_ident(from),
687                    ctx.dialect().quote_ident(to)
688                );
689                ctx.execute_raw_blueprint(&sql).await?;
690            }
691            MigrationOp::CreateEnum { name, values } => {
692                apply_create_enum(ctx, name, values).await?;
693            }
694            MigrationOp::AlterEnum { name, actions } => {
695                apply_alter_enum(ctx, name, actions).await?;
696            }
697            MigrationOp::DropEnum { name } => {
698                apply_drop_enum(ctx, name).await?;
699            }
700            MigrationOp::RenameEnum { from, to } => {
701                apply_rename_enum(ctx, from, to).await?;
702            }
703        }
704    }
705    Ok(())
706}
707
/// Builds the conventional index name `<table>_<col1>_<col2>…_idx`.
///
/// With no columns the result is `<table>__idx`.
fn default_index_name(table: &str, columns: &[String]) -> String {
    let joined_columns = columns.join("_");
    [table, joined_columns.as_str(), "idx"].join("_")
}
711
712async fn resolve_table_items(
713    ctx: &mut MigrationContext<'_>,
714    path: &std::path::Path,
715    items: &[TableItem],
716) -> Result<Vec<TableItem>, MigrationError> {
717    let mut resolved = Vec::with_capacity(items.len());
718    for item in items {
719        match item {
720            TableItem::Column(field) if field.reference.is_some() => {
721                resolved.push(TableItem::Column(
722                    resolve_reference_field(ctx, path, field).await?,
723                ));
724            }
725            _ => resolved.push(item.clone()),
726        }
727    }
728    Ok(resolved)
729}
730
731async fn resolve_alter_actions(
732    ctx: &mut MigrationContext<'_>,
733    path: &std::path::Path,
734    actions: &[AlterAction],
735) -> Result<Vec<AlterAction>, MigrationError> {
736    let mut resolved = Vec::with_capacity(actions.len());
737    for action in actions {
738        if let AlterAction::AddColumn(field) = action
739            && field.reference.is_some()
740        {
741            resolved.push(AlterAction::AddColumn(
742                resolve_reference_field(ctx, path, field).await?,
743            ));
744            continue;
745        }
746        resolved.push(action.clone());
747    }
748    Ok(resolved)
749}
750
751async fn resolve_reference_field(
752    ctx: &mut MigrationContext<'_>,
753    path: &std::path::Path,
754    field: &Field,
755) -> Result<Field, MigrationError> {
756    let mut resolved = field.clone();
757    let reference =
758        resolved
759            .reference
760            .as_ref()
761            .ok_or_else(|| MigrationError::InvalidMigrationFile {
762                path: path.to_path_buf(),
763                message: format!("missing reference metadata for `{}`", field.name),
764            })?;
765
766    if matches!(resolved.ty, FieldType::Implicit) {
767        let target_column = reference.column.as_deref().unwrap_or("id");
768        resolved.ty = map_schema_type_to_field_type(
769            ctx.column_type(&reference.table, target_column).await?,
770            path,
771            field,
772        )?;
773    }
774
775    Ok(resolved)
776}
777
778fn map_schema_type_to_field_type(
779    ty: crate::ColumnType,
780    _path: &std::path::Path,
781    _field: &Field,
782) -> Result<FieldType, MigrationError> {
783    Ok(match ty {
784        crate::ColumnType::Integer => FieldType::Integer,
785        crate::ColumnType::BigInt => FieldType::BigInt,
786        crate::ColumnType::Bool => FieldType::Boolean,
787        crate::ColumnType::Char(len) => FieldType::Varchar(len),
788        crate::ColumnType::Varchar(len) => FieldType::Varchar(len),
789        crate::ColumnType::Text => FieldType::Text,
790        crate::ColumnType::Date => FieldType::Date,
791        crate::ColumnType::Time => FieldType::Time,
792        crate::ColumnType::DateTime => FieldType::DateTime,
793        crate::ColumnType::Timestamp => FieldType::Timestamp,
794        crate::ColumnType::Decimal(precision, scale) => FieldType::Decimal(precision, scale),
795        crate::ColumnType::Float => FieldType::Float,
796        crate::ColumnType::Double => FieldType::Double,
797        crate::ColumnType::Json => FieldType::Json,
798        crate::ColumnType::Uuid => FieldType::Uuid,
799        crate::ColumnType::Custom(name) => FieldType::Custom(name),
800    })
801}
802
803fn apply_create_field(table: &mut crate::TableBlueprint, field: &Field) {
804    if let Some(reference) = &field.reference {
805        apply_reference_field(table, field, reference);
806        return;
807    }
808
809    match &field.ty {
810        FieldType::Implicit if field.primary && field.name == "id" => table.id(),
811        FieldType::Id => table.id(),
812        FieldType::String => {
813            let _ = apply_field_attrs(table.string(&field.name), field);
814        }
815        FieldType::Varchar(length) => {
816            let _ = apply_field_attrs(table.varchar(&field.name, *length), field);
817        }
818        FieldType::Text => {
819            let _ = apply_field_attrs(table.text(&field.name), field);
820        }
821        FieldType::Integer => {
822            let _ = apply_field_attrs(table.integer(&field.name), field);
823        }
824        FieldType::BigInt => {
825            let _ = apply_field_attrs(table.bigint(&field.name), field);
826        }
827        FieldType::Boolean => {
828            let _ = apply_field_attrs(table.boolean(&field.name), field);
829        }
830        FieldType::Date => {
831            let _ = apply_field_attrs(table.date(&field.name), field);
832        }
833        FieldType::Time => {
834            let _ = apply_field_attrs(table.time(&field.name), field);
835        }
836        FieldType::DateTime => {
837            let _ = apply_field_attrs(table.datetime(&field.name), field);
838        }
839        FieldType::Timestamp | FieldType::TimestampTz => {
840            let _ = apply_field_attrs(table.timestamp(&field.name), field);
841        }
842        FieldType::Decimal(precision, scale) => {
843            let _ = apply_field_attrs(table.decimal(&field.name, *precision, *scale), field);
844        }
845        FieldType::Float => {
846            let _ = apply_field_attrs(table.float(&field.name), field);
847        }
848        FieldType::Double => {
849            let _ = apply_field_attrs(table.double(&field.name), field);
850        }
851        FieldType::Json => {
852            let _ = apply_field_attrs(table.json(&field.name), field);
853        }
854        FieldType::Uuid => {
855            let _ = apply_field_attrs(table.uuid(&field.name), field);
856        }
857        FieldType::RememberToken => {
858            if field.name == "remember_token" {
859                let _ = apply_field_attrs(table.remember_token(), field);
860            } else {
861                let _ = apply_field_attrs(table.string(&field.name), field);
862            }
863        }
864        FieldType::Custom(name) => {
865            let _ = apply_field_attrs(
866                table.custom(&field.name, crate::ColumnType::Custom(name.clone())),
867                field,
868            );
869        }
870        FieldType::Implicit => {}
871    }
872}
873
874fn apply_alter_field(table: &mut crate::AlterTableBlueprint, field: &Field) {
875    match &field.ty {
876        FieldType::Implicit | FieldType::Id => {}
877        FieldType::String => {
878            let _ = apply_field_attrs(table.string(&field.name), field);
879        }
880        FieldType::Varchar(length) => {
881            let _ = apply_field_attrs(table.varchar(&field.name, *length), field);
882        }
883        FieldType::Text => {
884            let _ = apply_field_attrs(table.text(&field.name), field);
885        }
886        FieldType::Integer => {
887            let _ = apply_field_attrs(table.integer(&field.name), field);
888        }
889        FieldType::BigInt => {
890            let _ = apply_field_attrs(table.bigint(&field.name), field);
891        }
892        FieldType::Boolean => {
893            let _ = apply_field_attrs(table.boolean(&field.name), field);
894        }
895        FieldType::Date => {
896            let _ = apply_field_attrs(table.date(&field.name), field);
897        }
898        FieldType::Time => {
899            let _ = apply_field_attrs(table.time(&field.name), field);
900        }
901        FieldType::DateTime => {
902            let _ = apply_field_attrs(table.datetime(&field.name), field);
903        }
904        FieldType::Timestamp | FieldType::TimestampTz => {
905            let _ = apply_field_attrs(table.timestamp(&field.name), field);
906        }
907        FieldType::Decimal(precision, scale) => {
908            let _ = apply_field_attrs(table.decimal(&field.name, *precision, *scale), field);
909        }
910        FieldType::Float => {
911            let _ = apply_field_attrs(table.float(&field.name), field);
912        }
913        FieldType::Double => {
914            let _ = apply_field_attrs(table.double(&field.name), field);
915        }
916        FieldType::Json => {
917            let _ = apply_field_attrs(table.json(&field.name), field);
918        }
919        FieldType::Uuid => {
920            let _ = apply_field_attrs(table.uuid(&field.name), field);
921        }
922        FieldType::RememberToken => {
923            let _ = apply_field_attrs(table.string(&field.name), field);
924        }
925        FieldType::Custom(name) => {
926            let _ = apply_field_attrs(
927                table.custom(&field.name, crate::ColumnType::Custom(name.clone())),
928                field,
929            );
930        }
931    }
932}
933
934fn apply_reference_field(table: &mut crate::TableBlueprint, field: &Field, reference: &Reference) {
935    let mut builder = apply_field_attrs(
936        table.foreign(&field.name, relation_column_type(field)),
937        field,
938    );
939    if let Some(column) = &reference.column {
940        builder = builder.references_column(&reference.table, column);
941    } else {
942        builder = builder.references(&reference.table);
943    }
944    if field.index {
945        builder = builder.index();
946    }
947    if let Some(action) = reference.on_delete {
948        builder = match action {
949            ReferenceAction::Cascade => builder.cascade_on_delete(),
950            ReferenceAction::Restrict => builder.restrict_on_delete(),
951            ReferenceAction::SetNull => builder.null_on_delete(),
952            ReferenceAction::NoAction => builder.no_action_on_delete(),
953        };
954    } else {
955        builder = builder.restrict_on_delete();
956    }
957    if let Some(action) = reference.on_update {
958        builder = match action {
959            ReferenceAction::Cascade => builder.cascade_on_update(),
960            ReferenceAction::Restrict => builder.restrict_on_update(),
961            ReferenceAction::SetNull => builder.null_on_update(),
962            ReferenceAction::NoAction => builder.no_action_on_update(),
963        };
964    }
965    let _ = builder;
966}
967
968fn relation_column_type(field: &Field) -> crate::ColumnType {
969    match &field.ty {
970        FieldType::String => crate::ColumnType::Varchar(255),
971        FieldType::Text => crate::ColumnType::Text,
972        FieldType::Integer => crate::ColumnType::Integer,
973        FieldType::BigInt => crate::ColumnType::BigInt,
974        FieldType::Boolean => crate::ColumnType::Bool,
975        FieldType::Date => crate::ColumnType::Date,
976        FieldType::Time => crate::ColumnType::Time,
977        FieldType::DateTime => crate::ColumnType::DateTime,
978        FieldType::Timestamp | FieldType::TimestampTz => crate::ColumnType::Timestamp,
979        FieldType::Decimal(precision, scale) => crate::ColumnType::Decimal(*precision, *scale),
980        FieldType::Float => crate::ColumnType::Float,
981        FieldType::Double => crate::ColumnType::Double,
982        FieldType::Json => crate::ColumnType::Json,
983        FieldType::Uuid => crate::ColumnType::Uuid,
984        FieldType::Varchar(length) => crate::ColumnType::Varchar(*length),
985        FieldType::Id => crate::ColumnType::BigInt,
986        FieldType::RememberToken => crate::ColumnType::Varchar(100),
987        FieldType::Custom(name) => crate::ColumnType::Custom(name.clone()),
988        FieldType::Implicit => crate::ColumnType::BigInt,
989    }
990}
991
992async fn add_relation_column(
993    ctx: &mut MigrationContext<'_>,
994    table_name: &str,
995    field: &Field,
996) -> Result<(), MigrationError> {
997    let Some(reference) = &field.reference else {
998        return Ok(());
999    };
1000
1001    let sql = format!(
1002        "alter table {} add column {};",
1003        ctx.dialect().quote_ident(table_name),
1004        render_inline_relation_column(ctx.dialect(), field, reference)
1005    );
1006    ctx.execute_raw_blueprint(&sql).await?;
1007
1008    if field.index {
1009        let index = crate::IndexBlueprint::new(
1010            &default_index_name(table_name, std::slice::from_ref(&field.name)),
1011            table_name,
1012            [field.name.as_str()],
1013        );
1014        ctx.execute_raw_blueprint(&index.create_sql(ctx.dialect()))
1015            .await?;
1016    }
1017
1018    Ok(())
1019}
1020
1021fn render_inline_relation_column(
1022    dialect: crate::SchemaDialect,
1023    field: &Field,
1024    reference: &Reference,
1025) -> String {
1026    let column = crate::schema::blueprint::ColumnDef {
1027        name: field.name.clone(),
1028        ty: relation_column_type(field),
1029        nullable: matches!(field.nullable, Nullability::Nullable),
1030        primary_key: false,
1031        auto_increment: false,
1032        unique: field.unique,
1033        default_raw: field.default.clone(),
1034    };
1035    let mut out = crate::schema::render::render_column(dialect, &column);
1036    out.push_str(" references ");
1037    out.push_str(&dialect.quote_ident(&reference.table));
1038    out.push('(');
1039    out.push_str(&dialect.quote_ident(reference.column.as_deref().unwrap_or("id")));
1040    out.push(')');
1041    match reference.on_delete.unwrap_or(ReferenceAction::Restrict) {
1042        ReferenceAction::Cascade => out.push_str(" on delete cascade"),
1043        ReferenceAction::Restrict => out.push_str(" on delete restrict"),
1044        ReferenceAction::SetNull => out.push_str(" on delete set null"),
1045        ReferenceAction::NoAction => out.push_str(" on delete no action"),
1046    }
1047    if let Some(action) = reference.on_update {
1048        match action {
1049            ReferenceAction::Cascade => out.push_str(" on update cascade"),
1050            ReferenceAction::Restrict => out.push_str(" on update restrict"),
1051            ReferenceAction::SetNull => out.push_str(" on update set null"),
1052            ReferenceAction::NoAction => out.push_str(" on update no action"),
1053        }
1054    }
1055    out
1056}
1057
1058async fn apply_create_enum(
1059    ctx: &mut MigrationContext<'_>,
1060    name: &str,
1061    values: &[String],
1062) -> Result<(), MigrationError> {
1063    if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1064        let values = values
1065            .iter()
1066            .map(|value| format!("'{}'", value.replace('\'', "''")))
1067            .collect::<Vec<_>>()
1068            .join(", ");
1069        let sql = format!(
1070            "create type {} as enum ({values});",
1071            ctx.dialect().quote_ident(name)
1072        );
1073        ctx.execute_raw_blueprint(&sql).await?;
1074    }
1075    Ok(())
1076}
1077
1078async fn apply_alter_enum(
1079    ctx: &mut MigrationContext<'_>,
1080    name: &str,
1081    actions: &[AlterEnumAction],
1082) -> Result<(), MigrationError> {
1083    if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1084        for action in actions {
1085            match action {
1086                AlterEnumAction::AddValue(value) => {
1087                    let sql = format!(
1088                        "alter type {} add value if not exists '{}';",
1089                        ctx.dialect().quote_ident(name),
1090                        value.replace('\'', "''")
1091                    );
1092                    ctx.execute_raw_blueprint(&sql).await?;
1093                }
1094            }
1095        }
1096    }
1097    Ok(())
1098}
1099
1100async fn apply_drop_enum(ctx: &mut MigrationContext<'_>, name: &str) -> Result<(), MigrationError> {
1101    if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1102        let sql = format!("drop type if exists {};", ctx.dialect().quote_ident(name));
1103        ctx.execute_raw_blueprint(&sql).await?;
1104    }
1105    Ok(())
1106}
1107
1108async fn apply_rename_enum(
1109    ctx: &mut MigrationContext<'_>,
1110    from: &str,
1111    to: &str,
1112) -> Result<(), MigrationError> {
1113    if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1114        let sql = format!(
1115            "alter type {} rename to {};",
1116            ctx.dialect().quote_ident(from),
1117            ctx.dialect().quote_ident(to)
1118        );
1119        ctx.execute_raw_blueprint(&sql).await?;
1120    }
1121    Ok(())
1122}
1123
1124impl From<quex::Pool> for Migrator {
1125    fn from(pool: quex::Pool) -> Self {
1126        #[allow(unreachable_patterns)]
1127        let pool = match pool.driver() {
1128            #[cfg(feature = "sqlite")]
1129            quex::Driver::Sqlite => MigrationPool::Sqlite(pool),
1130            #[cfg(feature = "postgres")]
1131            quex::Driver::Pgsql => MigrationPool::Postgres(pool),
1132            #[cfg(feature = "mariadb")]
1133            quex::Driver::Mysql => MigrationPool::Mariadb(pool),
1134            other => panic!("unsupported pool driver for migrations: {other:?}"),
1135        };
1136        Self {
1137            pool,
1138            migration_dir: default_migration_dir(),
1139        }
1140    }
1141}
1142
1143impl From<&quex::Pool> for Migrator {
1144    fn from(pool: &quex::Pool) -> Self {
1145        pool.clone().into()
1146    }
1147}