// lift_migration/context/migration.rs

1use std::{collections::HashMap, future::Future, pin::Pin};
2
3#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
4use quex::{self, FromRow, Row};
5
6#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
7use crate::context::execute_table_blueprint;
8use crate::{
9    AlterTableBlueprint, BlueprintExecutor, ColumnType, IndexBlueprint, IntoSchemaColumns,
10    MigrationError, SchemaDialect, TableBlueprint,
11};
12
13pub type MigrationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), MigrationError>> + 'a>>;
14
15#[allow(dead_code)]
16fn no_backend_error() -> MigrationError {
17    MigrationError::BackendNotEnabled("no backend")
18}
19
20#[derive(Clone, Copy)]
21pub struct MigrationEntry {
22    pub name: &'static str,
23    pub version: u64,
24    pub up: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
25    pub down: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
26}
27
28impl MigrationEntry {
29    pub const fn new(
30        name: &'static str,
31        version: u64,
32        up: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
33        down: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
34    ) -> Self {
35        Self {
36            name,
37            version,
38            up,
39            down,
40        }
41    }
42}
43
44inventory::collect!(MigrationEntry);
45
46#[allow(async_fn_in_trait)]
47pub trait Migration {
48    async fn up(ctx: &mut MigrationContext<'_>) -> Result<(), MigrationError>;
49    async fn down(ctx: &mut MigrationContext<'_>) -> Result<(), MigrationError>;
50}
51
/// Destination for a backend context's statements: a borrowed pool, or a borrowed
/// transaction that every statement runs inside.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
enum MigrationExecutor<'a> {
    Pool(&'a quex::Pool),
    Transaction(&'a mut quex::PoolTransaction),
}

/// Column types already resolved by a context, keyed by `(table, column)`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
type ColumnTypeCache = HashMap<(String, String), ColumnType>;

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl MigrationExecutor<'_> {
    /// Runs `sql` with no bound parameters and returns the affected-row count.
    ///
    /// Driver errors are converted into [`MigrationError`] by `?`.
    async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
        let affected = match self {
            Self::Pool(pool) => quex::query(sql).execute(*pool).await?.rows_affected,
            Self::Transaction(tx) => quex::query(sql).execute(&mut **tx).await?.rows_affected,
        };
        Ok(affected)
    }
}
71
/// One row of SQLite's `pragma_table_info`, reduced to the declared column type.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
struct SqliteColumnTypeRow {
    data_type: String,
}

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl FromRow for SqliteColumnTypeRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        // The lookup query selects the pragma's `type` column.
        let declared: String = row.get("type")?;
        Ok(Self {
            data_type: declared,
        })
    }
}
85
/// The subset of `information_schema.columns` needed to rebuild a [`ColumnType`].
///
/// The Postgres query selects the catalog's `udt_name`. The MariaDB query aliases
/// `data_type as udt_name`, so both dialects decode into this one row type.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
struct InformationSchemaColumnRow {
    data_type: String,
    udt_name: Option<String>,
    character_maximum_length: Option<i64>,
    numeric_precision: Option<i64>,
    numeric_scale: Option<i64>,
}

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl FromRow for InformationSchemaColumnRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        // Columns are decoded in declaration order; the first decode error aborts.
        let data_type: String = row.get("data_type")?;
        let udt_name: Option<String> = row.get("udt_name")?;
        let character_maximum_length: Option<i64> = row.get("character_maximum_length")?;
        let numeric_precision: Option<i64> = row.get("numeric_precision")?;
        let numeric_scale: Option<i64> = row.get("numeric_scale")?;
        Ok(Self {
            data_type,
            udt_name,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
        })
    }
}
107
/// Renders `value` as a single-quoted SQL string literal.
///
/// Embedded single quotes are escaped by doubling them (`'` becomes `''`).
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn quote_string_literal(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            quoted.push_str("''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
112
/// Maps the type text reported by SQLite's `pragma_table_info` (for example
/// `VARCHAR(255)`) to a [`ColumnType`].
///
/// The text is handed unchanged to the shared string parser.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_sqlite_column_type(data_type: &str) -> Result<ColumnType, MigrationError> {
    parse_column_type_string(data_type)
}
117
/// Maps an `information_schema.columns` row (Postgres or MariaDB) to a [`ColumnType`].
///
/// A `udt_name` of `uuid` takes precedence over `data_type`. When the catalog reports
/// no size, length-carrying types fall back to defaults:
/// - `varchar` → 255
/// - `char` → 1
/// - `numeric`/`decimal` → (10, 0)
///
/// Every other `data_type` goes through the generic string parser.
///
/// # Errors
/// Returns `MigrationError::UnsupportedColumnType` in two cases:
/// - a reported length, precision or scale is negative or larger than `u32::MAX`
///   (the error carries `data_type`);
/// - the type name is not recognised.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_information_schema_column_type(
    row: &InformationSchemaColumnRow,
) -> Result<ColumnType, MigrationError> {
    if matches!(row.udt_name.as_deref(), Some("uuid")) {
        return Ok(ColumnType::Uuid);
    }

    // Catalog sizes arrive as signed 64-bit values. Convert them with a range check:
    // an `as u32` cast would silently wrap negative or oversized values into a bogus size.
    let size = |value: Option<i64>, default: i64| {
        u32::try_from(value.unwrap_or(default))
            .map_err(|_| MigrationError::UnsupportedColumnType(row.data_type.clone()))
    };

    match row.data_type.as_str() {
        "character varying" | "varchar" => {
            Ok(ColumnType::Varchar(size(row.character_maximum_length, 255)?))
        }
        "character" | "char" => Ok(ColumnType::Char(size(row.character_maximum_length, 1)?)),
        "numeric" | "decimal" => Ok(ColumnType::Decimal(
            size(row.numeric_precision, 10)?,
            size(row.numeric_scale, 0)?,
        )),
        other => parse_column_type_string(other),
    }
}
139
/// Splits a normalised declaration such as `decimal(10, 2)` into its base name
/// (`"decimal"`) and its trimmed, comma-separated arguments (`["10", "2"]`).
///
/// Returns `None` unless the text ends with a parenthesised argument list.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn split_type_arguments(normalized: &str) -> Option<(&str, Vec<&str>)> {
    let without_close = normalized.strip_suffix(')')?;
    let (name, arguments) = without_close.split_once('(')?;
    Some((name.trim_end(), arguments.split(',').map(str::trim).collect()))
}

/// Parses a textual column type declaration into a [`ColumnType`].
///
/// Matching is case-insensitive and ignores surrounding whitespace. Recognised forms:
/// - `varchar(n)` / `character varying(n)` → `Varchar(n)`
/// - `char(n)` / `character(n)` → `Char(n)`
/// - `decimal(p)` / `numeric(p)` → `Decimal(p, 0)`, since SQL gives a precision-only
///   decimal a scale of 0
/// - `decimal(p, s)` / `numeric(p, s)` → `Decimal(p, s)`
/// - the bare keywords listed in the final `match`
///
/// # Errors
/// Returns `MigrationError::UnsupportedColumnType` when the declaration is not
/// recognised:
/// - bad arguments to a known parameterised type (non-numeric, or the wrong count)
///   report the original `data_type`;
/// - unknown names report the trimmed, lower-cased text.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_column_type_string(data_type: &str) -> Result<ColumnType, MigrationError> {
    let unsupported = || MigrationError::UnsupportedColumnType(data_type.to_owned());
    let normalized = data_type.trim().to_ascii_lowercase();

    if let Some((name, arguments)) = split_type_arguments(&normalized) {
        let number = |value: &str| value.parse::<u32>().map_err(|_| unsupported());
        match (name, arguments.as_slice()) {
            ("varchar" | "character varying", &[length]) => {
                return Ok(ColumnType::Varchar(number(length)?));
            }
            ("char" | "character", &[length]) => {
                return Ok(ColumnType::Char(number(length)?));
            }
            ("decimal" | "numeric", &[precision]) => {
                return Ok(ColumnType::Decimal(number(precision)?, 0));
            }
            ("decimal" | "numeric", &[precision, scale]) => {
                return Ok(ColumnType::Decimal(number(precision)?, number(scale)?));
            }
            // A known parameterised type with the wrong number of arguments is
            // malformed. Reject it rather than silently dropping extra arguments.
            ("varchar" | "character varying" | "char" | "character" | "decimal" | "numeric", _) => {
                return Err(unsupported());
            }
            // Unknown parameterised names fall through to the keyword table below,
            // which reports them as unsupported.
            _ => {}
        }
    }

    match normalized.as_str() {
        "integer" | "int" => Ok(ColumnType::Integer),
        "bigint" => Ok(ColumnType::BigInt),
        "boolean" | "bool" => Ok(ColumnType::Bool),
        "text" => Ok(ColumnType::Text),
        "date" => Ok(ColumnType::Date),
        "time" | "time without time zone" => Ok(ColumnType::Time),
        "timestamp" | "timestamp without time zone" | "timestamp with time zone" | "datetime" => {
            Ok(ColumnType::Timestamp)
        }
        "json" | "jsonb" => Ok(ColumnType::Json),
        "uuid" => Ok(ColumnType::Uuid),
        "real" | "float" => Ok(ColumnType::Float),
        "double precision" | "double" => Ok(ColumnType::Double),
        other => Err(MigrationError::UnsupportedColumnType(other.to_owned())),
    }
}
191
/// Looks up the type of `table.column` through a connection pool.
///
/// How each dialect is queried:
/// - SQLite: `pragma_table_info`, with the table name inlined as an escaped string
///   literal.
/// - Postgres and MariaDB: `information_schema.columns`, restricted to
///   `current_schema()` / `database()` respectively.
///
/// Query errors are propagated via `?`. What happens when no row matches depends on
/// `quex`'s `one` (presumably an error).
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
async fn resolve_column_type_from_pool(
    dialect: SchemaDialect,
    pool: &quex::Pool,
    table: &str,
    column: &str,
) -> Result<ColumnType, MigrationError> {
    let catalog_sql = match dialect {
        SchemaDialect::Sqlite => {
            let pragma_sql = format!(
                "select type from pragma_table_info({}) where name = ? limit 1",
                quote_string_literal(table)
            );
            let declared = quex::query(&pragma_sql)
                .bind(column)
                .one::<SqliteColumnTypeRow>(pool)
                .await?;
            return parse_sqlite_column_type(&declared.data_type);
        }
        SchemaDialect::Postgres => {
            "select data_type, udt_name, character_maximum_length, numeric_precision, numeric_scale \
             from information_schema.columns \
             where table_schema = current_schema() and table_name = ? and column_name = ? \
             limit 1"
        }
        SchemaDialect::MariaDb => {
            "select data_type, data_type as udt_name, character_maximum_length, numeric_precision, numeric_scale \
             from information_schema.columns \
             where table_schema = database() and table_name = ? and column_name = ? \
             limit 1"
        }
    };

    let catalog_row = quex::query(catalog_sql)
        .bind(table)
        .bind(column)
        .one::<InformationSchemaColumnRow>(pool)
        .await?;
    parse_information_schema_column_type(&catalog_row)
}
239
/// Looks up the type of `table.column` inside an open transaction.
///
/// Issues the same dialect-specific catalog queries as the pool-based lookup, but
/// through the borrowed transaction, so the schema is read as the transaction sees it.
/// Query errors are propagated via `?`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
async fn resolve_column_type_from_tx(
    dialect: SchemaDialect,
    tx: &mut quex::PoolTransaction,
    table: &str,
    column: &str,
) -> Result<ColumnType, MigrationError> {
    let catalog_sql = match dialect {
        SchemaDialect::Sqlite => {
            let pragma_sql = format!(
                "select type from pragma_table_info({}) where name = ? limit 1",
                quote_string_literal(table)
            );
            let declared = quex::query(&pragma_sql)
                .bind(column)
                .one::<SqliteColumnTypeRow>(&mut *tx)
                .await?;
            return parse_sqlite_column_type(&declared.data_type);
        }
        SchemaDialect::Postgres => {
            "select data_type, udt_name, character_maximum_length, numeric_precision, numeric_scale \
             from information_schema.columns \
             where table_schema = current_schema() and table_name = ? and column_name = ? \
             limit 1"
        }
        SchemaDialect::MariaDb => {
            "select data_type, data_type as udt_name, character_maximum_length, numeric_precision, numeric_scale \
             from information_schema.columns \
             where table_schema = database() and table_name = ? and column_name = ? \
             limit 1"
        }
    };

    let catalog_row = quex::query(catalog_sql)
        .bind(table)
        .bind(column)
        .one::<InformationSchemaColumnRow>(&mut *tx)
        .await?;
    parse_information_schema_column_type(&catalog_row)
}
287
/// Generates the migration API for one database backend.
///
/// Every generated item is gated on `#[cfg(feature = $feature)]`:
/// - `$context<'a>`: runs statements on a borrowed `quex::Pool` or inside a borrowed
///   `quex::PoolTransaction`. It caches `column_type` lookups and offers table/index
///   helpers rendered for `$dialect`.
/// - A [`BlueprintExecutor`] impl for `$context`.
/// - `$entry`: a `Copy` registration record (name, version, `up`/`down` function
///   pointers), declared as an `inventory` collection.
/// - `$entry_trait`: an `async` `up`/`down` trait over `$context`.
macro_rules! define_backend {
    (
        feature = $feature:literal,
        context = $context:ident,
        entry = $entry:ident,
        entry_trait = $entry_trait:ident,
        dialect = $dialect:expr
    ) => {
        /// Backend-specific migration context generated by `define_backend!`.
        #[cfg(feature = $feature)]
        pub struct $context<'a> {
            executor: MigrationExecutor<'a>,
            // Memoised `column_type` answers keyed by (table, column). `execute_raw`
            // clears it, because any executed statement may change the schema.
            column_type_cache: ColumnTypeCache,
        }

        #[cfg(feature = $feature)]
        impl<'a> $context<'a> {
            const SCHEMA_DIALECT: SchemaDialect = $dialect;

            /// Creates a context that runs every statement directly on the pool.
            pub fn new(executor: &'a quex::Pool) -> Self {
                Self {
                    executor: MigrationExecutor::Pool(executor),
                    column_type_cache: HashMap::new(),
                }
            }

            /// Creates a context that runs every statement inside the borrowed transaction.
            pub fn from_transaction(executor: &'a mut quex::PoolTransaction) -> Self {
                Self {
                    executor: MigrationExecutor::Transaction(executor),
                    column_type_cache: HashMap::new(),
                }
            }

            /// Executes `sql` verbatim and returns the number of affected rows.
            ///
            /// The column-type cache is cleared first. The statement may add, drop or
            /// retype columns, and `column_type` must not keep answering from the
            /// schema as it was before. Every helper below funnels through here.
            pub async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
                self.column_type_cache.clear();
                self.executor.execute_raw(sql).await
            }

            /// Returns the type of `table.column` as reported by the database.
            ///
            /// The first lookup for a (table, column) pair queries the schema catalog.
            /// Repeat lookups are served from the cache until the next `execute_raw`.
            ///
            /// # Errors
            /// Propagates query errors. Returns `MigrationError::UnsupportedColumnType`
            /// when the reported type cannot be mapped to a [`ColumnType`].
            pub async fn column_type(
                &mut self,
                table: &str,
                column: &str,
            ) -> Result<ColumnType, MigrationError> {
                let cache_key = (table.to_owned(), column.to_owned());
                if let Some(cached) = self.column_type_cache.get(&cache_key) {
                    return Ok(cached.clone());
                }

                let resolved = match &mut self.executor {
                    MigrationExecutor::Pool(pool) => {
                        resolve_column_type_from_pool(Self::SCHEMA_DIALECT, pool, table, column)
                            .await
                    }
                    MigrationExecutor::Transaction(tx) => {
                        resolve_column_type_from_tx(Self::SCHEMA_DIALECT, tx, table, column).await
                    }
                }?;

                self.column_type_cache.insert(cache_key, resolved.clone());

                Ok(resolved)
            }

            /// Builds a [`TableBlueprint`] named `name` with `build`, then executes it.
            pub async fn create(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut TableBlueprint),
            ) -> Result<(), MigrationError> {
                let mut table = TableBlueprint::new(name);
                build(&mut table);
                execute_table_blueprint(self, table).await
            }

            /// Builds an [`AlterTableBlueprint`] for `name` with `build`, then executes
            /// its statements in order. Stops at the first failing statement.
            pub async fn alter_table(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut AlterTableBlueprint),
            ) -> Result<(), MigrationError> {
                let mut table = AlterTableBlueprint::new(name);
                build(&mut table);

                for sql in table.sql_statements(Self::SCHEMA_DIALECT) {
                    self.execute_raw(&sql).await?;
                }

                Ok(())
            }

            /// Alias for `alter_table`.
            pub async fn table(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut AlterTableBlueprint),
            ) -> Result<(), MigrationError> {
                self.alter_table(name, build).await
            }

            /// Drops table `name` using the dialect's drop statement.
            pub async fn drop(&mut self, name: &str) -> Result<(), MigrationError> {
                let table = TableBlueprint::new(name);
                self.execute_raw(&table.drop_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Creates index `name` on `table` over `columns`.
            pub async fn create_index(
                &mut self,
                name: &str,
                table: &str,
                columns: impl IntoSchemaColumns,
            ) -> Result<(), MigrationError> {
                let index = IndexBlueprint::new(name, table, columns);
                self.execute_raw(&index.create_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Creates unique index `name` on `table` over `columns`.
            pub async fn create_unique_index(
                &mut self,
                name: &str,
                table: &str,
                columns: impl IntoSchemaColumns,
            ) -> Result<(), MigrationError> {
                let index = IndexBlueprint::new_unique(name, table, columns);
                self.execute_raw(&index.create_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Drops index `name` using the dialect's drop-index statement.
            pub async fn drop_index(&mut self, name: &str) -> Result<(), MigrationError> {
                let index = IndexBlueprint::named(name);
                self.execute_raw(&index.drop_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }
        }

        #[cfg(feature = $feature)]
        impl<'a> BlueprintExecutor for $context<'a> {
            fn dialect(&self) -> SchemaDialect {
                Self::SCHEMA_DIALECT
            }

            async fn execute_raw_blueprint(&mut self, sql: &str) -> Result<u64, MigrationError> {
                Self::execute_raw(self, sql).await
            }
        }

        /// Registration record for a backend-specific migration.
        #[cfg(feature = $feature)]
        #[derive(Clone, Copy)]
        pub struct $entry {
            pub name: &'static str,
            pub version: u64,
            pub up: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
            pub down: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
        }

        #[cfg(feature = $feature)]
        impl $entry {
            /// Bundles name, version and `up`/`down` functions; usable in `const` items.
            pub const fn new(
                name: &'static str,
                version: u64,
                up: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
                down: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
            ) -> Self {
                Self {
                    name,
                    version,
                    up,
                    down,
                }
            }
        }

        #[cfg(feature = $feature)]
        inventory::collect!($entry);

        /// Backend-specific migration with `async` `up`/`down` steps.
        #[cfg(feature = $feature)]
        #[allow(async_fn_in_trait)]
        pub trait $entry_trait {
            async fn up(ctx: &mut $context<'_>) -> Result<(), MigrationError>;
            async fn down(ctx: &mut $context<'_>) -> Result<(), MigrationError>;
        }
    };
}
470
471define_backend!(
472    feature = "sqlite",
473    context = SqliteMigrationContext,
474    entry = SqliteMigrationEntry,
475    entry_trait = SqliteMigration,
476    dialect = SchemaDialect::Sqlite
477);
478
479define_backend!(
480    feature = "postgres",
481    context = PostgresMigrationContext,
482    entry = PostgresMigrationEntry,
483    entry_trait = PostgresMigration,
484    dialect = SchemaDialect::Postgres
485);
486
487define_backend!(
488    feature = "mariadb",
489    context = MariadbMigrationContext,
490    entry = MariadbMigrationEntry,
491    entry_trait = MariadbMigration,
492    dialect = SchemaDialect::MariaDb
493);
494
/// Backend-erased migration context.
///
/// Each variant wraps the context generated by `define_backend!` for one database
/// and exists only when that backend's feature is enabled. With no backend feature
/// compiled in, the `Disabled` placeholder keeps the `'a` parameter in use, and the
/// methods that reach the database return a "backend not enabled" error.
pub enum MigrationContext<'a> {
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteMigrationContext<'a>),
    #[cfg(feature = "postgres")]
    Postgres(PostgresMigrationContext<'a>),
    #[cfg(feature = "mariadb")]
    Mariadb(MariadbMigrationContext<'a>),
    #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
    Disabled(core::marker::PhantomData<&'a ()>),
}
505
506impl<'a> MigrationContext<'a> {
507    pub fn dialect(&self) -> SchemaDialect {
508        match self {
509            #[cfg(feature = "sqlite")]
510            Self::Sqlite(_) => SchemaDialect::Sqlite,
511            #[cfg(feature = "postgres")]
512            Self::Postgres(_) => SchemaDialect::Postgres,
513            #[cfg(feature = "mariadb")]
514            Self::Mariadb(_) => SchemaDialect::MariaDb,
515            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
516            Self::Disabled(_) => SchemaDialect::Sqlite,
517        }
518    }
519
520    pub async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
521        match self {
522            #[cfg(feature = "sqlite")]
523            Self::Sqlite(ctx) => ctx.execute_raw(sql).await,
524            #[cfg(feature = "postgres")]
525            Self::Postgres(ctx) => ctx.execute_raw(sql).await,
526            #[cfg(feature = "mariadb")]
527            Self::Mariadb(ctx) => ctx.execute_raw(sql).await,
528            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
529            Self::Disabled(_) => Err(no_backend_error()),
530        }
531    }
532
533    pub async fn column_type(
534        &mut self,
535        table: &str,
536        column: &str,
537    ) -> Result<ColumnType, MigrationError> {
538        match self {
539            #[cfg(feature = "sqlite")]
540            Self::Sqlite(ctx) => ctx.column_type(table, column).await,
541            #[cfg(feature = "postgres")]
542            Self::Postgres(ctx) => ctx.column_type(table, column).await,
543            #[cfg(feature = "mariadb")]
544            Self::Mariadb(ctx) => ctx.column_type(table, column).await,
545            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
546            Self::Disabled(_) => Err(no_backend_error()),
547        }
548    }
549
550    pub async fn create(
551        &mut self,
552        name: &str,
553        build: impl FnOnce(&mut TableBlueprint),
554    ) -> Result<(), MigrationError> {
555        let mut build = Some(build);
556        match self {
557            #[cfg(feature = "sqlite")]
558            Self::Sqlite(ctx) => ctx.create(name, build.take().unwrap()).await,
559            #[cfg(feature = "postgres")]
560            Self::Postgres(ctx) => ctx.create(name, build.take().unwrap()).await,
561            #[cfg(feature = "mariadb")]
562            Self::Mariadb(ctx) => ctx.create(name, build.take().unwrap()).await,
563            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
564            Self::Disabled(_) => Err(no_backend_error()),
565        }
566    }
567
568    pub async fn alter_table(
569        &mut self,
570        name: &str,
571        build: impl FnOnce(&mut AlterTableBlueprint),
572    ) -> Result<(), MigrationError> {
573        let mut build = Some(build);
574        match self {
575            #[cfg(feature = "sqlite")]
576            Self::Sqlite(ctx) => ctx.alter_table(name, build.take().unwrap()).await,
577            #[cfg(feature = "postgres")]
578            Self::Postgres(ctx) => ctx.alter_table(name, build.take().unwrap()).await,
579            #[cfg(feature = "mariadb")]
580            Self::Mariadb(ctx) => ctx.alter_table(name, build.take().unwrap()).await,
581            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
582            Self::Disabled(_) => Err(no_backend_error()),
583        }
584    }
585
586    pub async fn table(
587        &mut self,
588        name: &str,
589        build: impl FnOnce(&mut AlterTableBlueprint),
590    ) -> Result<(), MigrationError> {
591        self.alter_table(name, build).await
592    }
593
594    pub async fn drop(&mut self, name: &str) -> Result<(), MigrationError> {
595        match self {
596            #[cfg(feature = "sqlite")]
597            Self::Sqlite(ctx) => ctx.drop(name).await,
598            #[cfg(feature = "postgres")]
599            Self::Postgres(ctx) => ctx.drop(name).await,
600            #[cfg(feature = "mariadb")]
601            Self::Mariadb(ctx) => ctx.drop(name).await,
602            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
603            Self::Disabled(_) => Err(no_backend_error()),
604        }
605    }
606
607    pub async fn create_index(
608        &mut self,
609        name: &str,
610        table: &str,
611        columns: impl IntoSchemaColumns,
612    ) -> Result<(), MigrationError> {
613        let index = IndexBlueprint::new(name, table, columns);
614        self.execute_raw(&index.create_sql(self.dialect())).await?;
615        Ok(())
616    }
617
618    pub async fn create_unique_index(
619        &mut self,
620        name: &str,
621        table: &str,
622        columns: impl IntoSchemaColumns,
623    ) -> Result<(), MigrationError> {
624        let index = IndexBlueprint::new_unique(name, table, columns);
625        self.execute_raw(&index.create_sql(self.dialect())).await?;
626        Ok(())
627    }
628
629    pub async fn drop_index(&mut self, name: &str) -> Result<(), MigrationError> {
630        let index = IndexBlueprint::named(name);
631        self.execute_raw(&index.drop_sql(self.dialect())).await?;
632        Ok(())
633    }
634}
635
636impl<'a> BlueprintExecutor for MigrationContext<'a> {
637    fn dialect(&self) -> SchemaDialect {
638        Self::dialect(self)
639    }
640
641    async fn execute_raw_blueprint(&mut self, sql: &str) -> Result<u64, MigrationError> {
642        Self::execute_raw(self, sql).await
643    }
644}