1use std::{collections::HashMap, path::PathBuf};
2
3use quex::{self, Driver, FromRow, Pool, Row};
4
5use super::pool::MigrationPool;
6use crate::{
7 AlterColumnBuilder, BlueprintExecutor, ColumnBuilder, MigrationContext, MigrationError,
8 MigrationReport, ResetReport,
9 files::{
10 AlterAction, AlterEnumAction, Field, FieldType, MigrationFile, MigrationOp, Nullability,
11 Reference, ReferenceAction, TableItem, default_migration_dir, load_migrations,
12 },
13};
14
/// Sub-command selected on the migration command line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CliCommand {
    /// Apply every pending migration in a new batch.
    Run,
    /// Roll back every applied batch, then apply all migrations again.
    Refresh,
    /// Roll back the most recent applied batch.
    RollbackLastBatch,
    /// Roll back this many of the most recently applied migrations.
    RollbackSteps(usize),
    /// Roll back every applied batch.
    Reset,
    /// Only load and parse the migration files; nothing is applied or rolled back.
    Validate,
}
24
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct CliOptions {
27 pub database_url: String,
28 pub command: CliCommand,
29 pub migration_dir: Option<PathBuf>,
30}
31
32pub async fn run_cli_command(options: &CliOptions) -> Result<MigrationReport, MigrationError> {
33 let migrator = Migrator::connect(&options.database_url).await?.path(
34 options
35 .migration_dir
36 .clone()
37 .unwrap_or_else(default_migration_dir),
38 );
39 match options.command {
40 CliCommand::Run => migrator.run().await,
41 CliCommand::Refresh => migrator.refresh().await,
42 CliCommand::RollbackLastBatch => migrator.rollback_last_batch().await,
43 CliCommand::RollbackSteps(steps) => migrator.rollback_steps(steps).await,
44 CliCommand::Reset => {
45 let report = migrator.reset().await?;
46 Ok(MigrationReport {
47 batch: None,
48 applied: Vec::new(),
49 rolled_back: report.rolled_back,
50 })
51 }
52 CliCommand::Validate => {
53 migrator.validate()?;
54 Ok(MigrationReport::default())
55 }
56 }
57}
58
59struct BatchValue {
60 batch: Option<i64>,
61}
62
63impl FromRow for BatchValue {
64 fn from_row(row: &Row) -> quex::Result<Self> {
65 Ok(Self {
66 batch: row.get("batch")?,
67 })
68 }
69}
70
71struct ExistsValue {
72 value: i64,
73}
74
75impl FromRow for ExistsValue {
76 fn from_row(row: &Row) -> quex::Result<Self> {
77 Ok(Self {
78 value: row.get("value")?,
79 })
80 }
81}
82
83struct AppliedMigrationRow {
84 id: i64,
85 checksum: String,
86 status: String,
87}
88
89impl FromRow for AppliedMigrationRow {
90 fn from_row(row: &Row) -> quex::Result<Self> {
91 Ok(Self {
92 id: row.get("id")?,
93 checksum: row.get("checksum")?,
94 status: row.get("status")?,
95 })
96 }
97}
98
99struct AppliedMigrationId {
100 id: i64,
101}
102
103impl FromRow for AppliedMigrationId {
104 fn from_row(row: &Row) -> quex::Result<Self> {
105 Ok(Self { id: row.get("id")? })
106 }
107}
108
109pub struct Migrator {
110 pub(crate) pool: MigrationPool,
111 migration_dir: PathBuf,
112}
113
114impl Migrator {
115 pub fn new(pool: impl Into<Self>) -> Self {
116 pool.into()
117 }
118
119 pub fn path(mut self, path: impl Into<PathBuf>) -> Self {
120 self.migration_dir = path.into();
121 self
122 }
123
124 pub async fn connect(database_url: &str) -> Result<Self, MigrationError> {
125 let pool = Pool::connect(database_url)?.max_size(1).build().await?;
126 let migrator = match pool.driver() {
127 Driver::Sqlite => {
128 #[cfg(feature = "sqlite")]
129 {
130 Ok::<Self, MigrationError>(Self {
131 pool: MigrationPool::Sqlite(pool),
132 migration_dir: default_migration_dir(),
133 })
134 }
135 #[cfg(not(feature = "sqlite"))]
136 {
137 Err(MigrationError::BackendNotEnabled("sqlite"))
138 }
139 }
140 Driver::Pgsql => {
141 #[cfg(feature = "postgres")]
142 {
143 Ok::<Self, MigrationError>(Self {
144 pool: MigrationPool::Postgres(pool),
145 migration_dir: default_migration_dir(),
146 })
147 }
148 #[cfg(not(feature = "postgres"))]
149 {
150 Err(MigrationError::BackendNotEnabled("postgres"))
151 }
152 }
153 Driver::Mysql => {
154 #[cfg(feature = "mariadb")]
155 {
156 Ok::<Self, MigrationError>(Self {
157 pool: MigrationPool::Mariadb(pool),
158 migration_dir: default_migration_dir(),
159 })
160 }
161 #[cfg(not(feature = "mariadb"))]
162 {
163 Err(MigrationError::BackendNotEnabled("mariadb"))
164 }
165 }
166 }?;
167 migration_trace!(
168 backend = migrator.pool.backend_name(),
169 "connecting migrator"
170 );
171 Ok(migrator)
172 }
173
174 pub fn validate(&self) -> Result<(), MigrationError> {
175 let _ = self.load_migrations()?;
176 Ok(())
177 }
178
179 pub async fn reset(&self) -> Result<ResetReport, MigrationError> {
180 let mut rolled_back = Vec::new();
181 let mut batches = Vec::new();
182
183 loop {
184 let batch_report = self.rollback_last_batch().await?;
185 if batch_report.rolled_back.is_empty() {
186 break;
187 }
188
189 rolled_back.extend(batch_report.rolled_back);
190 if let Some(batch) = batch_report.batch {
191 batches.push(batch);
192 }
193 }
194
195 Ok(ResetReport {
196 rolled_back,
197 batches,
198 })
199 }
200
201 pub async fn refresh(&self) -> Result<MigrationReport, MigrationError> {
202 let reset_report = self.reset().await?;
203 let up_report = self.run().await?;
204
205 Ok(MigrationReport {
206 batch: up_report.batch,
207 applied: up_report.applied,
208 rolled_back: reset_report.rolled_back,
209 })
210 }
211
212 pub async fn run(&self) -> Result<MigrationReport, MigrationError> {
213 migration_trace!(backend = self.pool.backend_name(), "starting migration run");
214 self.ensure_table().await?;
215
216 let migrations = self.load_migrations()?;
217 let mut applied = self.applied_migrations().await?;
218 let mut retryable_failed_ids = Vec::new();
219
220 for migration in &migrations {
221 if let Some(row) = applied.get(&migration.id) {
222 if row.status == "failed" {
223 if self.pool.supports_clean_failed_retry() {
224 retryable_failed_ids.push(migration.id);
225 continue;
226 }
227 return Err(MigrationError::FailedMigration { id: migration.id });
228 }
229 if row.checksum != migration.checksum {
230 return Err(MigrationError::ChecksumMismatch { id: migration.id });
231 }
232 }
233 }
234
235 for id in retryable_failed_ids {
236 self.delete_applied(id).await?;
237 applied.remove(&id);
238 }
239
240 let pending: Vec<_> = migrations
241 .into_iter()
242 .filter(|migration| !applied.contains_key(&migration.id))
243 .collect();
244
245 if pending.is_empty() {
246 return Ok(MigrationReport::default());
247 }
248
249 let batch = self.next_batch().await?;
250 let mut report = MigrationReport {
251 batch: Some(batch),
252 applied: Vec::with_capacity(pending.len()),
253 rolled_back: Vec::new(),
254 };
255
256 for migration in pending {
257 self.mark_running(&migration, batch).await?;
258 match self.apply(&migration, &migration.up).await {
259 Ok(()) => {
260 self.mark_applied(&migration.id).await?;
261 report.applied.push(migration.id);
262 }
263 Err(error) => {
264 self.mark_failed(migration.id).await?;
265 return Err(error);
266 }
267 }
268 }
269
270 Ok(report)
271 }
272
273 pub async fn rollback_last_batch(&self) -> Result<MigrationReport, MigrationError> {
274 self.ensure_table().await?;
275
276 let Some(batch) = self.last_batch().await? else {
277 return Ok(MigrationReport::default());
278 };
279
280 let ids = self.ids_for_batch(batch).await?;
281 self.rollback_ids(ids, Some(batch)).await
282 }
283
284 pub async fn rollback_steps(&self, steps: usize) -> Result<MigrationReport, MigrationError> {
285 self.ensure_table().await?;
286 if steps == 0 {
287 return Ok(MigrationReport::default());
288 }
289
290 let ids = self.latest_ids(steps).await?;
291 self.rollback_ids(ids, None).await
292 }
293
294 fn load_migrations(&self) -> Result<Vec<MigrationFile>, MigrationError> {
295 load_migrations(&self.migration_dir)
296 }
297
298 async fn rollback_ids(
299 &self,
300 ids: Vec<u64>,
301 batch: Option<u64>,
302 ) -> Result<MigrationReport, MigrationError> {
303 if ids.is_empty() {
304 return Ok(MigrationReport::default());
305 }
306
307 let migrations = self
308 .load_migrations()?
309 .into_iter()
310 .map(|migration| (migration.id, migration))
311 .collect::<HashMap<_, _>>();
312
313 let mut report = MigrationReport {
314 batch,
315 applied: Vec::new(),
316 rolled_back: Vec::with_capacity(ids.len()),
317 };
318
319 for id in ids {
320 let migration = migrations
321 .get(&id)
322 .ok_or(MigrationError::MissingMigration(id))?;
323 self.apply(migration, &migration.down).await?;
324 self.delete_applied(id).await?;
325 report.rolled_back.push(id);
326 }
327
328 Ok(report)
329 }
330
331 async fn apply(
332 &self,
333 migration: &MigrationFile,
334 ops: &[MigrationOp],
335 ) -> Result<(), MigrationError> {
336 let mut tx = self.pool.pool().begin().await?;
337 let result = {
338 let mut ctx = self.new_transaction_context(&mut tx);
339 apply_operations(&mut ctx, &migration.path, ops).await
340 };
341 match result {
342 Ok(()) => tx.commit().await.map_err(Into::into),
343 Err(error) => {
344 let _ = tx.rollback().await;
345 Err(error)
346 }
347 }
348 }
349
350 fn new_transaction_context<'a>(
351 &self,
352 tx: &'a mut quex::PoolTransaction,
353 ) -> MigrationContext<'a> {
354 match &self.pool {
355 #[cfg(feature = "sqlite")]
356 MigrationPool::Sqlite(_) => {
357 MigrationContext::Sqlite(crate::SqliteMigrationContext::from_transaction(tx))
358 }
359 #[cfg(feature = "postgres")]
360 MigrationPool::Postgres(_) => {
361 MigrationContext::Postgres(crate::PostgresMigrationContext::from_transaction(tx))
362 }
363 #[cfg(feature = "mariadb")]
364 MigrationPool::Mariadb(_) => {
365 MigrationContext::Mariadb(crate::MariadbMigrationContext::from_transaction(tx))
366 }
367 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
368 MigrationPool::Disabled => MigrationContext::Disabled(std::marker::PhantomData),
369 }
370 }
371
372 async fn ensure_table(&self) -> Result<(), MigrationError> {
373 let exists_sql = match &self.pool {
374 #[cfg(feature = "sqlite")]
375 MigrationPool::Sqlite(_) => {
376 "select 1 as value from sqlite_master where type = 'table' and name = 'migrations' limit 1"
377 }
378 #[cfg(feature = "postgres")]
379 MigrationPool::Postgres(_) => {
380 "select 1 as value from pg_catalog.pg_class c join pg_catalog.pg_namespace n on n.oid = c.relnamespace where c.relkind = 'r' and c.relname = 'migrations' and n.nspname = current_schema() limit 1"
381 }
382 #[cfg(feature = "mariadb")]
383 MigrationPool::Mariadb(_) => {
384 "select 1 as value from information_schema.tables where table_schema = database() and table_name = 'migrations' limit 1"
385 }
386 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
387 MigrationPool::Disabled => return Err(MigrationError::BackendNotEnabled("no backend")),
388 };
389 let exists = quex::query(exists_sql)
390 .optional::<ExistsValue>(self.pool.pool())
391 .await?
392 .map(|row| row.value != 0)
393 .unwrap_or(false);
394 if exists {
395 return Ok(());
396 }
397
398 let sql = match &self.pool {
399 #[cfg(feature = "sqlite")]
400 MigrationPool::Sqlite(_) => {
401 "create table if not exists migrations (id integer primary key not null, name text not null, checksum text not null, batch integer not null, status text not null, started_at text not null default current_timestamp, finished_at text null)"
402 }
403 #[cfg(feature = "postgres")]
404 MigrationPool::Postgres(_) => {
405 "create table if not exists migrations (id bigint primary key not null, name text not null, checksum text not null, batch bigint not null, status text not null, started_at timestamptz not null default current_timestamp, finished_at timestamptz null)"
406 }
407 #[cfg(feature = "mariadb")]
408 MigrationPool::Mariadb(_) => {
409 "create table if not exists migrations (id bigint primary key not null, name varchar(255) not null, checksum varchar(255) not null, batch bigint not null, status varchar(32) not null, started_at timestamp not null default current_timestamp, finished_at timestamp null)"
410 }
411 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
412 MigrationPool::Disabled => return Err(MigrationError::BackendNotEnabled("no backend")),
413 };
414 quex::query(sql).execute(self.pool.pool()).await?;
415 Ok(())
416 }
417
418 async fn applied_migrations(
419 &self,
420 ) -> Result<HashMap<u64, AppliedMigrationRow>, MigrationError> {
421 let rows = quex::query("select id, checksum, status from migrations order by id")
422 .all::<AppliedMigrationRow>(self.pool.pool())
423 .await?;
424 Ok(rows
425 .into_iter()
426 .map(|row| (row.id as u64, row))
427 .collect::<HashMap<_, _>>())
428 }
429
430 async fn next_batch(&self) -> Result<u64, MigrationError> {
431 Ok(self.max_batch().await?.unwrap_or(0) + 1)
432 }
433
434 async fn last_batch(&self) -> Result<Option<u64>, MigrationError> {
435 self.max_batch().await
436 }
437
438 async fn max_batch(&self) -> Result<Option<u64>, MigrationError> {
439 let row =
440 quex::query("select max(batch) as batch from migrations where status = 'applied'")
441 .one::<BatchValue>(self.pool.pool())
442 .await?;
443 Ok(row.batch.map(|value| value as u64))
444 }
445
446 async fn ids_for_batch(&self, batch: u64) -> Result<Vec<u64>, MigrationError> {
447 let ids = quex::query(
448 "select id from migrations where batch = ? and status = 'applied' order by id desc",
449 )
450 .bind(batch as i64)
451 .all::<AppliedMigrationId>(self.pool.pool())
452 .await?;
453 Ok(ids.into_iter().map(|value| value.id as u64).collect())
454 }
455
456 async fn latest_ids(&self, steps: usize) -> Result<Vec<u64>, MigrationError> {
457 let ids = quex::query("select id from migrations where status = 'applied' order by batch desc, id desc limit ?")
458 .bind(steps as i64)
459 .all::<AppliedMigrationId>(self.pool.pool())
460 .await?;
461 Ok(ids.into_iter().map(|value| value.id as u64).collect())
462 }
463
464 async fn mark_running(
465 &self,
466 migration: &MigrationFile,
467 batch: u64,
468 ) -> Result<(), MigrationError> {
469 quex::query("delete from migrations where id = ?")
470 .bind(migration.id as i64)
471 .execute(self.pool.pool())
472 .await?;
473 quex::query(
474 "insert into migrations(id, name, checksum, batch, status) values(?, ?, ?, ?, ?)",
475 )
476 .bind(migration.id as i64)
477 .bind(&migration.name)
478 .bind(&migration.checksum)
479 .bind(batch as i64)
480 .bind("running")
481 .execute(self.pool.pool())
482 .await?;
483 Ok(())
484 }
485
486 async fn mark_applied(&self, id: &u64) -> Result<(), MigrationError> {
487 quex::query(
488 "update migrations set status = ?, finished_at = current_timestamp where id = ?",
489 )
490 .bind("applied")
491 .bind(*id as i64)
492 .execute(self.pool.pool())
493 .await?;
494 Ok(())
495 }
496
497 async fn mark_failed(&self, id: u64) -> Result<(), MigrationError> {
498 quex::query(
499 "update migrations set status = ?, finished_at = current_timestamp where id = ?",
500 )
501 .bind("failed")
502 .bind(id as i64)
503 .execute(self.pool.pool())
504 .await?;
505 Ok(())
506 }
507
508 async fn delete_applied(&self, id: u64) -> Result<(), MigrationError> {
509 quex::query("delete from migrations where id = ?")
510 .bind(id as i64)
511 .execute(self.pool.pool())
512 .await?;
513 Ok(())
514 }
515}
516
/// Attribute setters shared by the column builders.
///
/// This lets `apply_field_attrs` configure create-table, alter-table, and foreign-key
/// columns through one generic code path.
trait ColumnAttrs: Sized {
    /// Marks the column as nullable.
    fn nullable(self) -> Self;

    /// Requests a unique constraint on the column.
    fn unique(self) -> Self;

    /// Sets the column default from a raw expression string.
    fn default_raw(self, value: &str) -> Self;
}
522
/// Forwards the shared attribute setters to `ColumnBuilder`'s own methods.
///
/// Each body calls `Self::<method>(self)`. Rust resolves such paths to an inherent method
/// of the same name before any trait method. Assuming `ColumnBuilder` defines these inherent
/// methods (not visible here), the calls forward instead of recursing into this impl.
/// Keep the path form when editing.
impl<'a> ColumnAttrs for ColumnBuilder<'a> {
    fn nullable(self) -> Self {
        Self::nullable(self)
    }

    fn unique(self) -> Self {
        Self::unique(self)
    }

    fn default_raw(self, value: &str) -> Self {
        Self::default_raw(self, value)
    }
}
536
/// Forwards the shared attribute setters to `AlterColumnBuilder`'s own methods.
///
/// `Self::<method>(self)` resolves to an inherent method of the same name before any trait
/// method. Assuming the builder defines them (not visible here), these calls forward rather
/// than recurse.
impl<'a> ColumnAttrs for AlterColumnBuilder<'a> {
    fn nullable(self) -> Self {
        Self::nullable(self)
    }

    fn unique(self) -> Self {
        Self::unique(self)
    }

    fn default_raw(self, value: &str) -> Self {
        Self::default_raw(self, value)
    }
}
550
/// Forwards the shared attribute setters to `ForeignKeyBuilder`'s own methods.
///
/// `Self::<method>(self)` resolves to an inherent method of the same name before any trait
/// method. Assuming the builder defines them (not visible here), these calls forward rather
/// than recurse.
impl<'a> ColumnAttrs for crate::ForeignKeyBuilder<'a> {
    fn nullable(self) -> Self {
        Self::nullable(self)
    }

    fn unique(self) -> Self {
        Self::unique(self)
    }

    fn default_raw(self, value: &str) -> Self {
        Self::default_raw(self, value)
    }
}
564
565fn apply_field_attrs<B>(builder: B, field: &Field) -> B
566where
567 B: ColumnAttrs,
568{
569 let mut builder = builder;
570 if matches!(field.nullable, Nullability::Nullable) {
571 builder = builder.nullable();
572 }
573 if field.unique {
574 builder = builder.unique();
575 }
576 if let Some(default) = &field.default {
577 builder = builder.default_raw(default);
578 }
579 builder
580}
581
/// Executes parsed migration operations in order against a migration context.
///
/// `path` is the migration file the operations came from. Here it is only forwarded to
/// the reference-resolution helpers, which use it in error messages. Enum operations are
/// delegated to the `apply_*_enum` helpers.
///
/// # Errors
/// Returns the first error from any operation; the remaining operations are not run.
async fn apply_operations(
    ctx: &mut MigrationContext<'_>,
    path: &std::path::Path,
    ops: &[MigrationOp],
) -> Result<(), MigrationError> {
    for op in ops {
        match op {
            MigrationOp::Sql { sql } => {
                ctx.execute_raw(sql).await?;
            }
            // Executed exactly like `Sql`.
            MigrationOp::Backfill { sql } => {
                ctx.execute_raw(sql).await?;
            }
            MigrationOp::CreateTable { name, items } => {
                // Implicitly-typed reference columns need awaited schema lookups, so they are
                // resolved before the non-async table closure below runs.
                let items = resolve_table_items(ctx, path, items).await?;
                ctx.create(name, |table| {
                    for item in &items {
                        match item {
                            TableItem::Column(field) => apply_create_field(table, field),
                            // Unnamed indexes get `<table>_<cols>_idx`.
                            TableItem::Index {
                                name: index_name,
                                columns,
                            } => table.index(
                                index_name
                                    .as_deref()
                                    .unwrap_or(&default_index_name(name, columns)),
                                columns.clone(),
                            ),
                            TableItem::Unique(columns) => table.unique(columns.clone()),
                            TableItem::ConstraintUnique { name, columns } => {
                                table.unique_named(name, columns.clone())
                            }
                            TableItem::Primary(columns) => table.primary(columns.clone()),
                            TableItem::Timestamps => table.timestamps(),
                            TableItem::SoftDeletes => {
                                table.timestamp("deleted_at").nullable();
                            }
                        }
                    }
                })
                .await?;
            }
            MigrationOp::AlterTable { name, actions } => {
                let actions = resolve_alter_actions(ctx, path, actions).await?;
                // Pass 1: index drops are issued before any column changes.
                for action in &actions {
                    if let AlterAction::DropIndex(index_name) = action {
                        let index = crate::IndexBlueprint::named(index_name);
                        ctx.execute_raw_blueprint(&index.drop_sql(ctx.dialect()))
                            .await?;
                    }
                }
                // Pass 2: plain column additions and drops go through the alter blueprint.
                // Reference columns, renames and index additions are skipped here.
                ctx.alter_table(name, |table| {
                    for action in &actions {
                        match action {
                            AlterAction::AddColumn(field) if field.reference.is_none() => {
                                apply_alter_field(table, field)
                            }
                            AlterAction::DropColumn(name) => table.drop_column(name),
                            AlterAction::RenameColumn { .. } => {}
                            AlterAction::AddIndex { .. } => {}
                            AlterAction::DropIndex(_) | AlterAction::AddColumn(_) => {}
                        }
                    }
                })
                .await?;
                // Pass 3: renames, index additions and reference columns run as raw SQL,
                // in declaration order.
                for action in &actions {
                    match action {
                        AlterAction::RenameColumn { from, to } => {
                            let sql = format!(
                                "alter table {} rename column {} to {};",
                                ctx.dialect().quote_ident(name),
                                ctx.dialect().quote_ident(from),
                                ctx.dialect().quote_ident(to)
                            );
                            ctx.execute_raw_blueprint(&sql).await?;
                        }
                        AlterAction::AddIndex {
                            name: index_name,
                            columns,
                        } => {
                            let index = crate::IndexBlueprint::new(
                                index_name
                                    .as_deref()
                                    .unwrap_or(&default_index_name(name, columns)),
                                name,
                                columns.clone(),
                            );
                            ctx.execute_raw_blueprint(&index.create_sql(ctx.dialect()))
                                .await?;
                        }
                        AlterAction::AddColumn(field) if field.reference.is_some() => {
                            add_relation_column(ctx, name, field).await?;
                        }
                        AlterAction::DropIndex(_) => {}
                        AlterAction::AddColumn(_) | AlterAction::DropColumn(_) => {}
                    }
                }
            }
            MigrationOp::DropTable { name } => {
                ctx.drop(name).await?;
            }
            MigrationOp::RenameTable { from, to } => {
                let sql = format!(
                    "alter table {} rename to {};",
                    ctx.dialect().quote_ident(from),
                    ctx.dialect().quote_ident(to)
                );
                ctx.execute_raw_blueprint(&sql).await?;
            }
            MigrationOp::CreateEnum { name, values } => {
                apply_create_enum(ctx, name, values).await?;
            }
            MigrationOp::AlterEnum { name, actions } => {
                apply_alter_enum(ctx, name, actions).await?;
            }
            MigrationOp::DropEnum { name } => {
                apply_drop_enum(ctx, name).await?;
            }
            MigrationOp::RenameEnum { from, to } => {
                apply_rename_enum(ctx, from, to).await?;
            }
        }
    }
    Ok(())
}
707
/// Builds the conventional index name `<table>_<col1>_<col2>..._idx`.
fn default_index_name(table: &str, columns: &[String]) -> String {
    let mut index_name = String::from(table);
    index_name.push('_');
    index_name.push_str(&columns.join("_"));
    index_name.push_str("_idx");
    index_name
}
711
712async fn resolve_table_items(
713 ctx: &mut MigrationContext<'_>,
714 path: &std::path::Path,
715 items: &[TableItem],
716) -> Result<Vec<TableItem>, MigrationError> {
717 let mut resolved = Vec::with_capacity(items.len());
718 for item in items {
719 match item {
720 TableItem::Column(field) if field.reference.is_some() => {
721 resolved.push(TableItem::Column(
722 resolve_reference_field(ctx, path, field).await?,
723 ));
724 }
725 _ => resolved.push(item.clone()),
726 }
727 }
728 Ok(resolved)
729}
730
731async fn resolve_alter_actions(
732 ctx: &mut MigrationContext<'_>,
733 path: &std::path::Path,
734 actions: &[AlterAction],
735) -> Result<Vec<AlterAction>, MigrationError> {
736 let mut resolved = Vec::with_capacity(actions.len());
737 for action in actions {
738 if let AlterAction::AddColumn(field) = action
739 && field.reference.is_some()
740 {
741 resolved.push(AlterAction::AddColumn(
742 resolve_reference_field(ctx, path, field).await?,
743 ));
744 continue;
745 }
746 resolved.push(action.clone());
747 }
748 Ok(resolved)
749}
750
751async fn resolve_reference_field(
752 ctx: &mut MigrationContext<'_>,
753 path: &std::path::Path,
754 field: &Field,
755) -> Result<Field, MigrationError> {
756 let mut resolved = field.clone();
757 let reference =
758 resolved
759 .reference
760 .as_ref()
761 .ok_or_else(|| MigrationError::InvalidMigrationFile {
762 path: path.to_path_buf(),
763 message: format!("missing reference metadata for `{}`", field.name),
764 })?;
765
766 if matches!(resolved.ty, FieldType::Implicit) {
767 let target_column = reference.column.as_deref().unwrap_or("id");
768 resolved.ty = map_schema_type_to_field_type(
769 ctx.column_type(&reference.table, target_column).await?,
770 path,
771 field,
772 )?;
773 }
774
775 Ok(resolved)
776}
777
778fn map_schema_type_to_field_type(
779 ty: crate::ColumnType,
780 _path: &std::path::Path,
781 _field: &Field,
782) -> Result<FieldType, MigrationError> {
783 Ok(match ty {
784 crate::ColumnType::Integer => FieldType::Integer,
785 crate::ColumnType::BigInt => FieldType::BigInt,
786 crate::ColumnType::Bool => FieldType::Boolean,
787 crate::ColumnType::Char(len) => FieldType::Varchar(len),
788 crate::ColumnType::Varchar(len) => FieldType::Varchar(len),
789 crate::ColumnType::Text => FieldType::Text,
790 crate::ColumnType::Date => FieldType::Date,
791 crate::ColumnType::Time => FieldType::Time,
792 crate::ColumnType::DateTime => FieldType::DateTime,
793 crate::ColumnType::Timestamp => FieldType::Timestamp,
794 crate::ColumnType::Decimal(precision, scale) => FieldType::Decimal(precision, scale),
795 crate::ColumnType::Float => FieldType::Float,
796 crate::ColumnType::Double => FieldType::Double,
797 crate::ColumnType::Json => FieldType::Json,
798 crate::ColumnType::Uuid => FieldType::Uuid,
799 crate::ColumnType::Custom(name) => FieldType::Custom(name),
800 })
801}
802
/// Adds one parsed field to a `create table` blueprint.
///
/// Fields with reference metadata are delegated to `apply_reference_field`. Otherwise the
/// column is created from `field.ty`, and the nullable/unique/default modifiers are applied
/// through `apply_field_attrs`.
///
/// Special cases:
/// - `FieldType::Id` calls `table.id()`, which receives neither `field.name` nor the field
///   modifiers. So does an `Implicit` field that is primary and named `id`.
/// - `RememberToken` uses `table.remember_token()` only when the field is named
///   `remember_token`; any other name gets a plain string column.
/// - Any other `Implicit` field adds no column. NOTE(review): confirm that silently skipping
///   it is intended.
fn apply_create_field(table: &mut crate::TableBlueprint, field: &Field) {
    if let Some(reference) = &field.reference {
        apply_reference_field(table, field, reference);
        return;
    }

    match &field.ty {
        FieldType::Implicit if field.primary && field.name == "id" => table.id(),
        FieldType::Id => table.id(),
        FieldType::String => {
            let _ = apply_field_attrs(table.string(&field.name), field);
        }
        FieldType::Varchar(length) => {
            let _ = apply_field_attrs(table.varchar(&field.name, *length), field);
        }
        FieldType::Text => {
            let _ = apply_field_attrs(table.text(&field.name), field);
        }
        FieldType::Integer => {
            let _ = apply_field_attrs(table.integer(&field.name), field);
        }
        FieldType::BigInt => {
            let _ = apply_field_attrs(table.bigint(&field.name), field);
        }
        FieldType::Boolean => {
            let _ = apply_field_attrs(table.boolean(&field.name), field);
        }
        FieldType::Date => {
            let _ = apply_field_attrs(table.date(&field.name), field);
        }
        FieldType::Time => {
            let _ = apply_field_attrs(table.time(&field.name), field);
        }
        FieldType::DateTime => {
            let _ = apply_field_attrs(table.datetime(&field.name), field);
        }
        // Both timestamp flavours use the same blueprint method.
        FieldType::Timestamp | FieldType::TimestampTz => {
            let _ = apply_field_attrs(table.timestamp(&field.name), field);
        }
        FieldType::Decimal(precision, scale) => {
            let _ = apply_field_attrs(table.decimal(&field.name, *precision, *scale), field);
        }
        FieldType::Float => {
            let _ = apply_field_attrs(table.float(&field.name), field);
        }
        FieldType::Double => {
            let _ = apply_field_attrs(table.double(&field.name), field);
        }
        FieldType::Json => {
            let _ = apply_field_attrs(table.json(&field.name), field);
        }
        FieldType::Uuid => {
            let _ = apply_field_attrs(table.uuid(&field.name), field);
        }
        FieldType::RememberToken => {
            if field.name == "remember_token" {
                let _ = apply_field_attrs(table.remember_token(), field);
            } else {
                let _ = apply_field_attrs(table.string(&field.name), field);
            }
        }
        FieldType::Custom(name) => {
            let _ = apply_field_attrs(
                table.custom(&field.name, crate::ColumnType::Custom(name.clone())),
                field,
            );
        }
        FieldType::Implicit => {}
    }
}
873
/// Adds one parsed, non-reference field to an `alter table` blueprint.
///
/// The column is created from `field.ty`, and its modifiers are applied through
/// `apply_field_attrs`.
///
/// `Implicit` and `Id` fields add nothing here. `RememberToken` always becomes a plain
/// string column named `field.name`; the create-table path uses `remember_token()` for a
/// field named `remember_token`. NOTE(review): confirm these differences from
/// `apply_create_field` are intended.
fn apply_alter_field(table: &mut crate::AlterTableBlueprint, field: &Field) {
    match &field.ty {
        FieldType::Implicit | FieldType::Id => {}
        FieldType::String => {
            let _ = apply_field_attrs(table.string(&field.name), field);
        }
        FieldType::Varchar(length) => {
            let _ = apply_field_attrs(table.varchar(&field.name, *length), field);
        }
        FieldType::Text => {
            let _ = apply_field_attrs(table.text(&field.name), field);
        }
        FieldType::Integer => {
            let _ = apply_field_attrs(table.integer(&field.name), field);
        }
        FieldType::BigInt => {
            let _ = apply_field_attrs(table.bigint(&field.name), field);
        }
        FieldType::Boolean => {
            let _ = apply_field_attrs(table.boolean(&field.name), field);
        }
        FieldType::Date => {
            let _ = apply_field_attrs(table.date(&field.name), field);
        }
        FieldType::Time => {
            let _ = apply_field_attrs(table.time(&field.name), field);
        }
        FieldType::DateTime => {
            let _ = apply_field_attrs(table.datetime(&field.name), field);
        }
        // Both timestamp flavours use the same blueprint method.
        FieldType::Timestamp | FieldType::TimestampTz => {
            let _ = apply_field_attrs(table.timestamp(&field.name), field);
        }
        FieldType::Decimal(precision, scale) => {
            let _ = apply_field_attrs(table.decimal(&field.name, *precision, *scale), field);
        }
        FieldType::Float => {
            let _ = apply_field_attrs(table.float(&field.name), field);
        }
        FieldType::Double => {
            let _ = apply_field_attrs(table.double(&field.name), field);
        }
        FieldType::Json => {
            let _ = apply_field_attrs(table.json(&field.name), field);
        }
        FieldType::Uuid => {
            let _ = apply_field_attrs(table.uuid(&field.name), field);
        }
        FieldType::RememberToken => {
            let _ = apply_field_attrs(table.string(&field.name), field);
        }
        FieldType::Custom(name) => {
            let _ = apply_field_attrs(
                table.custom(&field.name, crate::ColumnType::Custom(name.clone())),
                field,
            );
        }
    }
}
933
934fn apply_reference_field(table: &mut crate::TableBlueprint, field: &Field, reference: &Reference) {
935 let mut builder = apply_field_attrs(
936 table.foreign(&field.name, relation_column_type(field)),
937 field,
938 );
939 if let Some(column) = &reference.column {
940 builder = builder.references_column(&reference.table, column);
941 } else {
942 builder = builder.references(&reference.table);
943 }
944 if field.index {
945 builder = builder.index();
946 }
947 if let Some(action) = reference.on_delete {
948 builder = match action {
949 ReferenceAction::Cascade => builder.cascade_on_delete(),
950 ReferenceAction::Restrict => builder.restrict_on_delete(),
951 ReferenceAction::SetNull => builder.null_on_delete(),
952 ReferenceAction::NoAction => builder.no_action_on_delete(),
953 };
954 } else {
955 builder = builder.restrict_on_delete();
956 }
957 if let Some(action) = reference.on_update {
958 builder = match action {
959 ReferenceAction::Cascade => builder.cascade_on_update(),
960 ReferenceAction::Restrict => builder.restrict_on_update(),
961 ReferenceAction::SetNull => builder.null_on_update(),
962 ReferenceAction::NoAction => builder.no_action_on_update(),
963 };
964 }
965 let _ = builder;
966}
967
968fn relation_column_type(field: &Field) -> crate::ColumnType {
969 match &field.ty {
970 FieldType::String => crate::ColumnType::Varchar(255),
971 FieldType::Text => crate::ColumnType::Text,
972 FieldType::Integer => crate::ColumnType::Integer,
973 FieldType::BigInt => crate::ColumnType::BigInt,
974 FieldType::Boolean => crate::ColumnType::Bool,
975 FieldType::Date => crate::ColumnType::Date,
976 FieldType::Time => crate::ColumnType::Time,
977 FieldType::DateTime => crate::ColumnType::DateTime,
978 FieldType::Timestamp | FieldType::TimestampTz => crate::ColumnType::Timestamp,
979 FieldType::Decimal(precision, scale) => crate::ColumnType::Decimal(*precision, *scale),
980 FieldType::Float => crate::ColumnType::Float,
981 FieldType::Double => crate::ColumnType::Double,
982 FieldType::Json => crate::ColumnType::Json,
983 FieldType::Uuid => crate::ColumnType::Uuid,
984 FieldType::Varchar(length) => crate::ColumnType::Varchar(*length),
985 FieldType::Id => crate::ColumnType::BigInt,
986 FieldType::RememberToken => crate::ColumnType::Varchar(100),
987 FieldType::Custom(name) => crate::ColumnType::Custom(name.clone()),
988 FieldType::Implicit => crate::ColumnType::BigInt,
989 }
990}
991
992async fn add_relation_column(
993 ctx: &mut MigrationContext<'_>,
994 table_name: &str,
995 field: &Field,
996) -> Result<(), MigrationError> {
997 let Some(reference) = &field.reference else {
998 return Ok(());
999 };
1000
1001 let sql = format!(
1002 "alter table {} add column {};",
1003 ctx.dialect().quote_ident(table_name),
1004 render_inline_relation_column(ctx.dialect(), field, reference)
1005 );
1006 ctx.execute_raw_blueprint(&sql).await?;
1007
1008 if field.index {
1009 let index = crate::IndexBlueprint::new(
1010 &default_index_name(table_name, std::slice::from_ref(&field.name)),
1011 table_name,
1012 [field.name.as_str()],
1013 );
1014 ctx.execute_raw_blueprint(&index.create_sql(ctx.dialect()))
1015 .await?;
1016 }
1017
1018 Ok(())
1019}
1020
1021fn render_inline_relation_column(
1022 dialect: crate::SchemaDialect,
1023 field: &Field,
1024 reference: &Reference,
1025) -> String {
1026 let column = crate::schema::blueprint::ColumnDef {
1027 name: field.name.clone(),
1028 ty: relation_column_type(field),
1029 nullable: matches!(field.nullable, Nullability::Nullable),
1030 primary_key: false,
1031 auto_increment: false,
1032 unique: field.unique,
1033 default_raw: field.default.clone(),
1034 };
1035 let mut out = crate::schema::render::render_column(dialect, &column);
1036 out.push_str(" references ");
1037 out.push_str(&dialect.quote_ident(&reference.table));
1038 out.push('(');
1039 out.push_str(&dialect.quote_ident(reference.column.as_deref().unwrap_or("id")));
1040 out.push(')');
1041 match reference.on_delete.unwrap_or(ReferenceAction::Restrict) {
1042 ReferenceAction::Cascade => out.push_str(" on delete cascade"),
1043 ReferenceAction::Restrict => out.push_str(" on delete restrict"),
1044 ReferenceAction::SetNull => out.push_str(" on delete set null"),
1045 ReferenceAction::NoAction => out.push_str(" on delete no action"),
1046 }
1047 if let Some(action) = reference.on_update {
1048 match action {
1049 ReferenceAction::Cascade => out.push_str(" on update cascade"),
1050 ReferenceAction::Restrict => out.push_str(" on update restrict"),
1051 ReferenceAction::SetNull => out.push_str(" on update set null"),
1052 ReferenceAction::NoAction => out.push_str(" on update no action"),
1053 }
1054 }
1055 out
1056}
1057
1058async fn apply_create_enum(
1059 ctx: &mut MigrationContext<'_>,
1060 name: &str,
1061 values: &[String],
1062) -> Result<(), MigrationError> {
1063 if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1064 let values = values
1065 .iter()
1066 .map(|value| format!("'{}'", value.replace('\'', "''")))
1067 .collect::<Vec<_>>()
1068 .join(", ");
1069 let sql = format!(
1070 "create type {} as enum ({values});",
1071 ctx.dialect().quote_ident(name)
1072 );
1073 ctx.execute_raw_blueprint(&sql).await?;
1074 }
1075 Ok(())
1076}
1077
1078async fn apply_alter_enum(
1079 ctx: &mut MigrationContext<'_>,
1080 name: &str,
1081 actions: &[AlterEnumAction],
1082) -> Result<(), MigrationError> {
1083 if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1084 for action in actions {
1085 match action {
1086 AlterEnumAction::AddValue(value) => {
1087 let sql = format!(
1088 "alter type {} add value if not exists '{}';",
1089 ctx.dialect().quote_ident(name),
1090 value.replace('\'', "''")
1091 );
1092 ctx.execute_raw_blueprint(&sql).await?;
1093 }
1094 }
1095 }
1096 }
1097 Ok(())
1098}
1099
1100async fn apply_drop_enum(ctx: &mut MigrationContext<'_>, name: &str) -> Result<(), MigrationError> {
1101 if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1102 let sql = format!("drop type if exists {};", ctx.dialect().quote_ident(name));
1103 ctx.execute_raw_blueprint(&sql).await?;
1104 }
1105 Ok(())
1106}
1107
1108async fn apply_rename_enum(
1109 ctx: &mut MigrationContext<'_>,
1110 from: &str,
1111 to: &str,
1112) -> Result<(), MigrationError> {
1113 if matches!(ctx.dialect(), crate::SchemaDialect::Postgres) {
1114 let sql = format!(
1115 "alter type {} rename to {};",
1116 ctx.dialect().quote_ident(from),
1117 ctx.dialect().quote_ident(to)
1118 );
1119 ctx.execute_raw_blueprint(&sql).await?;
1120 }
1121 Ok(())
1122}
1123
1124impl From<quex::Pool> for Migrator {
1125 fn from(pool: quex::Pool) -> Self {
1126 #[allow(unreachable_patterns)]
1127 let pool = match pool.driver() {
1128 #[cfg(feature = "sqlite")]
1129 quex::Driver::Sqlite => MigrationPool::Sqlite(pool),
1130 #[cfg(feature = "postgres")]
1131 quex::Driver::Pgsql => MigrationPool::Postgres(pool),
1132 #[cfg(feature = "mariadb")]
1133 quex::Driver::Mysql => MigrationPool::Mariadb(pool),
1134 other => panic!("unsupported pool driver for migrations: {other:?}"),
1135 };
1136 Self {
1137 pool,
1138 migration_dir: default_migration_dir(),
1139 }
1140 }
1141}
1142
1143impl From<&quex::Pool> for Migrator {
1144 fn from(pool: &quex::Pool) -> Self {
1145 pool.clone().into()
1146 }
1147}