1use std::{collections::HashMap, future::Future, pin::Pin};
2
3#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
4use quex::{self, FromRow, Row};
5
6#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
7use crate::context::execute_table_blueprint;
8use crate::{
9 AlterTableBlueprint, BlueprintExecutor, ColumnType, IndexBlueprint, IntoSchemaColumns,
10 MigrationError, SchemaDialect, TableBlueprint,
11};
12
13pub type MigrationFuture<'a> = Pin<Box<dyn Future<Output = Result<(), MigrationError>> + 'a>>;
14
15#[allow(dead_code)]
16fn no_backend_error() -> MigrationError {
17 MigrationError::BackendNotEnabled("no backend")
18}
19
20#[derive(Clone, Copy)]
21pub struct MigrationEntry {
22 pub name: &'static str,
23 pub version: u64,
24 pub up: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
25 pub down: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
26}
27
28impl MigrationEntry {
29 pub const fn new(
30 name: &'static str,
31 version: u64,
32 up: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
33 down: for<'a> fn(&'a mut MigrationContext<'a>) -> MigrationFuture<'a>,
34 ) -> Self {
35 Self {
36 name,
37 version,
38 up,
39 down,
40 }
41 }
42}
43
44inventory::collect!(MigrationEntry);
45
/// Backend-agnostic migration, written against [`MigrationContext`] so the same code runs on
/// whichever backend the context wraps.
// `async fn` in a public trait triggers rustc's `async_fn_in_trait` lint (callers cannot
// require the returned futures to be `Send`); the lint is silenced for this trait.
#[allow(async_fn_in_trait)]
pub trait Migration {
    /// Forward step of the migration.
    async fn up(ctx: &mut MigrationContext<'_>) -> Result<(), MigrationError>;
    /// Reverse step of the migration.
    async fn down(ctx: &mut MigrationContext<'_>) -> Result<(), MigrationError>;
}
51
/// Destination for a backend context's SQL: an open transaction or the shared pool.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
enum MigrationExecutor<'a> {
    /// Statements run inside this borrowed transaction.
    Transaction(&'a mut quex::PoolTransaction),
    /// Statements run directly on the pool.
    Pool(&'a quex::Pool),
}

/// Memo of resolved column types, keyed by `(table, column)`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
type ColumnTypeCache = HashMap<(String, String), ColumnType>;
60
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl MigrationExecutor<'_> {
    /// Runs `sql` (no bind parameters) on the underlying pool or transaction and reports
    /// how many rows it affected.
    async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
        let summary = match self {
            Self::Transaction(transaction) => {
                quex::query(sql).execute(&mut **transaction).await?
            }
            Self::Pool(pool) => quex::query(sql).execute(*pool).await?,
        };
        Ok(summary.rows_affected)
    }
}
71
/// One row of SQLite's `pragma_table_info`, reduced to the declared column type.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
struct SqliteColumnTypeRow {
    /// Text of the pragma's `type` column.
    data_type: String,
}

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl FromRow for SqliteColumnTypeRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let data_type: String = row.get("type")?;
        Ok(Self { data_type })
    }
}
85
/// The `information_schema.columns` fields needed to rebuild a [`ColumnType`] on
/// Postgres and MariaDB.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
struct InformationSchemaColumnRow {
    /// Base type name (`data_type` column).
    data_type: String,
    /// Underlying type name; the MariaDB query aliases `data_type` into this slot.
    udt_name: Option<String>,
    /// Declared length for character types, NULL otherwise.
    character_maximum_length: Option<i64>,
    /// Declared precision for numeric types, NULL otherwise.
    numeric_precision: Option<i64>,
    /// Declared scale for numeric types, NULL otherwise.
    numeric_scale: Option<i64>,
}

#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
impl FromRow for InformationSchemaColumnRow {
    fn from_row(row: &Row) -> quex::Result<Self> {
        let data_type: String = row.get("data_type")?;
        let udt_name: Option<String> = row.get("udt_name")?;
        let character_maximum_length: Option<i64> = row.get("character_maximum_length")?;
        let numeric_precision: Option<i64> = row.get("numeric_precision")?;
        let numeric_scale: Option<i64> = row.get("numeric_scale")?;
        Ok(Self {
            data_type,
            udt_name,
            character_maximum_length,
            numeric_precision,
            numeric_scale,
        })
    }
}
107
/// Renders `value` as a single-quoted SQL string literal, doubling every embedded `'`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn quote_string_literal(value: &str) -> String {
    let escaped = value.replace('\'', "''");
    let mut literal = String::with_capacity(escaped.len() + 2);
    literal.push('\'');
    literal.push_str(&escaped);
    literal.push('\'');
    literal
}
112
/// Interprets the declared type text returned by SQLite's `pragma_table_info`.
///
/// No SQLite-specific handling is applied; the text goes straight to the shared parser.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_sqlite_column_type(data_type: &str) -> Result<ColumnType, MigrationError> {
    self::parse_column_type_string(data_type)
}
117
/// Converts an `information_schema.columns` row (Postgres or MariaDB) into a [`ColumnType`].
///
/// A `udt_name` of `uuid` takes precedence over `data_type`. Type names are matched
/// case-insensitively, like [`parse_column_type_string`], which handles every name not
/// listed here. NULL length/precision/scale metadata falls back to defaults:
/// 255 for `varchar`, 1 for `char`, and `(10, 0)` for `decimal`/`numeric`.
///
/// # Errors
///
/// Returns [`MigrationError::UnsupportedColumnType`] (carrying `data_type`) when the type
/// is not recognised, or when a length/precision/scale value is negative or does not fit
/// in `u32`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_information_schema_column_type(
    row: &InformationSchemaColumnRow,
) -> Result<ColumnType, MigrationError> {
    // Checked narrowing: an `as` cast would silently wrap negative or oversized metadata.
    let dimension = |value: Option<i64>, default: u32| -> Result<u32, MigrationError> {
        value.map_or(Ok(default), |raw| {
            u32::try_from(raw)
                .map_err(|_| MigrationError::UnsupportedColumnType(row.data_type.clone()))
        })
    };

    if row
        .udt_name
        .as_deref()
        .is_some_and(|udt| udt.trim().eq_ignore_ascii_case("uuid"))
    {
        return Ok(ColumnType::Uuid);
    }

    match row.data_type.trim().to_ascii_lowercase().as_str() {
        "character varying" | "varchar" => Ok(ColumnType::Varchar(dimension(
            row.character_maximum_length,
            255,
        )?)),
        "character" | "char" => Ok(ColumnType::Char(dimension(
            row.character_maximum_length,
            1,
        )?)),
        "numeric" | "decimal" => Ok(ColumnType::Decimal(
            dimension(row.numeric_precision, 10)?,
            dimension(row.numeric_scale, 0)?,
        )),
        _ => parse_column_type_string(&row.data_type),
    }
}
139
/// Parses a textual SQL column type (a declared SQLite type, or an
/// `information_schema.columns.data_type` value) into a [`ColumnType`].
///
/// Matching ignores ASCII case and surrounding whitespace. Parameterised forms accepted:
/// `varchar(n)` / `character varying(n)`, `char(n)` / `character(n)`, and
/// `decimal(p[, s])` / `numeric(p[, s])`, where an omitted scale means 0 as in the SQL
/// standard. Whitespace is allowed between the type name and `(` and around arguments.
///
/// # Errors
///
/// Returns [`MigrationError::UnsupportedColumnType`] carrying the original `data_type`
/// when the name is not recognised or its arguments are missing, malformed, or too many.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn parse_column_type_string(data_type: &str) -> Result<ColumnType, MigrationError> {
    let unsupported = || MigrationError::UnsupportedColumnType(data_type.to_owned());
    let parse_u32 = |text: &str| text.trim().parse::<u32>().map_err(|_| unsupported());

    let normalized = data_type.trim().to_ascii_lowercase();

    // Parameterised form: `<base>(<args>)`. The '(' is ASCII, so both slices fall on
    // char boundaries.
    if let Some(open) = normalized.find('(') {
        let base = normalized[..open].trim_end();
        let args = normalized[open + 1..]
            .strip_suffix(')')
            .ok_or_else(unsupported)?;
        return match base {
            "varchar" | "character varying" => Ok(ColumnType::Varchar(parse_u32(args)?)),
            "char" | "character" => Ok(ColumnType::Char(parse_u32(args)?)),
            "decimal" | "numeric" => {
                // `split_once` leaves any third argument inside `scale`, which then fails
                // to parse, so `decimal(p, s, x)` is rejected rather than truncated.
                let (precision, scale) = match args.split_once(',') {
                    Some((precision, scale)) => (parse_u32(precision)?, parse_u32(scale)?),
                    None => (parse_u32(args)?, 0),
                };
                Ok(ColumnType::Decimal(precision, scale))
            }
            _ => Err(unsupported()),
        };
    }

    match normalized.as_str() {
        "integer" | "int" => Ok(ColumnType::Integer),
        "bigint" => Ok(ColumnType::BigInt),
        "boolean" | "bool" => Ok(ColumnType::Bool),
        "text" => Ok(ColumnType::Text),
        "date" => Ok(ColumnType::Date),
        "time" | "time without time zone" => Ok(ColumnType::Time),
        "timestamp" | "timestamp without time zone" | "timestamp with time zone" | "datetime" => {
            Ok(ColumnType::Timestamp)
        }
        "json" | "jsonb" => Ok(ColumnType::Json),
        "uuid" => Ok(ColumnType::Uuid),
        "real" | "float" => Ok(ColumnType::Float),
        "double precision" | "double" => Ok(ColumnType::Double),
        _ => Err(unsupported()),
    }
}
191
/// Postgres column-type lookup in the current schema. Binds: table name, then column name.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
const POSTGRES_COLUMN_TYPE_SQL: &str =
    "select data_type, udt_name, character_maximum_length, numeric_precision, numeric_scale \
     from information_schema.columns \
     where table_schema = current_schema() and table_name = ? and column_name = ? \
     limit 1";

/// MariaDB column-type lookup in the current database. Binds: table name, then column name.
/// `data_type` is aliased as `udt_name` so the row decodes into the same
/// [`InformationSchemaColumnRow`] shape as on Postgres.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
const MARIADB_COLUMN_TYPE_SQL: &str =
    "select data_type, data_type as udt_name, character_maximum_length, numeric_precision, numeric_scale \
     from information_schema.columns \
     where table_schema = database() and table_name = ? and column_name = ? \
     limit 1";

/// SQLite column-type lookup via `pragma_table_info`. The table name is embedded as an
/// escaped string literal; only the column name is bound.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn sqlite_column_type_sql(table: &str) -> String {
    format!(
        "select type from pragma_table_info({}) where name = ? limit 1",
        quote_string_literal(table)
    )
}

/// Information-schema lookup query for `dialect`, or `None` for SQLite, which uses
/// [`sqlite_column_type_sql`] instead.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
fn information_schema_column_type_sql(dialect: SchemaDialect) -> Option<&'static str> {
    match dialect {
        SchemaDialect::Sqlite => None,
        SchemaDialect::Postgres => Some(POSTGRES_COLUMN_TYPE_SQL),
        SchemaDialect::MariaDb => Some(MARIADB_COLUMN_TYPE_SQL),
    }
}

/// Resolves the current type of `table.column` by querying the catalog through `pool`.
///
/// # Errors
///
/// Propagates the lookup query's error (including whatever `one` reports when no row
/// matches) and type-parsing errors.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
async fn resolve_column_type_from_pool(
    dialect: SchemaDialect,
    pool: &quex::Pool,
    table: &str,
    column: &str,
) -> Result<ColumnType, MigrationError> {
    let Some(sql) = information_schema_column_type_sql(dialect) else {
        let sql = sqlite_column_type_sql(table);
        let row = quex::query(&sql)
            .bind(column)
            .one::<SqliteColumnTypeRow>(pool)
            .await?;
        return parse_sqlite_column_type(&row.data_type);
    };
    let row = quex::query(sql)
        .bind(table)
        .bind(column)
        .one::<InformationSchemaColumnRow>(pool)
        .await?;
    parse_information_schema_column_type(&row)
}

/// Same as [`resolve_column_type_from_pool`], but queries inside the open transaction `tx`
/// so that schema changes made earlier in the transaction are visible.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mariadb"))]
async fn resolve_column_type_from_tx(
    dialect: SchemaDialect,
    tx: &mut quex::PoolTransaction,
    table: &str,
    column: &str,
) -> Result<ColumnType, MigrationError> {
    let Some(sql) = information_schema_column_type_sql(dialect) else {
        let sql = sqlite_column_type_sql(table);
        let row = quex::query(&sql)
            .bind(column)
            .one::<SqliteColumnTypeRow>(&mut *tx)
            .await?;
        return parse_sqlite_column_type(&row.data_type);
    };
    let row = quex::query(sql)
        .bind(table)
        .bind(column)
        .one::<InformationSchemaColumnRow>(&mut *tx)
        .await?;
    parse_information_schema_column_type(&row)
}
287
// Generates the concrete migration API for one database backend. Every generated item is
// gated on `#[cfg(feature = $feature)]`:
//
// * `$context<'a>`: a migration context fixed to `$dialect` that runs SQL either on a pool
//   or inside a borrowed transaction, and memoises `column_type` lookups;
// * `$entry`: an `inventory`-collected registration record holding `up`/`down` function
//   pointers over `$context`;
// * `$entry_trait`: the async trait that migrations for this backend implement.
macro_rules! define_backend {
    (
        feature = $feature:literal,
        context = $context:ident,
        entry = $entry:ident,
        entry_trait = $entry_trait:ident,
        dialect = $dialect:expr
    ) => {
        /// Migration context bound to a single database backend and SQL dialect.
        #[cfg(feature = $feature)]
        pub struct $context<'a> {
            /// Pool or transaction that every statement is sent to.
            executor: MigrationExecutor<'a>,
            /// Column types resolved by `column_type`, keyed by `(table, column)`.
            /// Emptied by `execute_raw`, because any executed statement may change the schema.
            column_type_cache: ColumnTypeCache,
        }

        #[cfg(feature = $feature)]
        impl<'a> $context<'a> {
            /// Dialect used to render blueprint SQL and to introspect column types.
            const SCHEMA_DIALECT: SchemaDialect = $dialect;

            /// Creates a context that runs each statement directly on the pool.
            pub fn new(executor: &'a quex::Pool) -> Self {
                Self {
                    executor: MigrationExecutor::Pool(executor),
                    column_type_cache: HashMap::new(),
                }
            }

            /// Creates a context that runs each statement inside the borrowed transaction.
            pub fn from_transaction(executor: &'a mut quex::PoolTransaction) -> Self {
                Self {
                    executor: MigrationExecutor::Transaction(executor),
                    column_type_cache: HashMap::new(),
                }
            }

            /// Executes `sql` as-is and returns the number of affected rows.
            ///
            /// The column-type cache is cleared before the statement runs. `alter_table`,
            /// `drop` and the index helpers issue their DDL through this method, as does the
            /// `BlueprintExecutor` impl. Without the reset, a `column_type` call made after
            /// an `ALTER`/`DROP` in the same migration would return the pre-change type.
            pub async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
                self.column_type_cache.clear();
                self.executor.execute_raw(sql).await
            }

            /// Returns the current type of `table.column`, memoised until the next
            /// `execute_raw`.
            ///
            /// # Errors
            ///
            /// Propagates catalog-query errors and
            /// `MigrationError::UnsupportedColumnType` for unrecognised types.
            pub async fn column_type(
                &mut self,
                table: &str,
                column: &str,
            ) -> Result<ColumnType, MigrationError> {
                let cache_key = (table.to_owned(), column.to_owned());
                if let Some(cached) = self.column_type_cache.get(&cache_key) {
                    return Ok(cached.clone());
                }

                let resolved = match &mut self.executor {
                    MigrationExecutor::Pool(pool) => {
                        resolve_column_type_from_pool(Self::SCHEMA_DIALECT, pool, table, column)
                            .await
                    }
                    MigrationExecutor::Transaction(tx) => {
                        resolve_column_type_from_tx(Self::SCHEMA_DIALECT, tx, table, column).await
                    }
                }?;

                self.column_type_cache.insert(cache_key, resolved.clone());

                Ok(resolved)
            }

            /// Builds a table definition with `build` and executes it.
            pub async fn create(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut TableBlueprint),
            ) -> Result<(), MigrationError> {
                let mut table = TableBlueprint::new(name);
                build(&mut table);
                execute_table_blueprint(self, table).await
            }

            /// Builds an alteration of table `name` with `build` and executes its statements
            /// in order, stopping at the first failure.
            pub async fn alter_table(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut AlterTableBlueprint),
            ) -> Result<(), MigrationError> {
                let mut table = AlterTableBlueprint::new(name);
                build(&mut table);

                for sql in table.sql_statements(Self::SCHEMA_DIALECT) {
                    self.execute_raw(&sql).await?;
                }

                Ok(())
            }

            /// Alias for `alter_table`.
            pub async fn table(
                &mut self,
                name: &str,
                build: impl FnOnce(&mut AlterTableBlueprint),
            ) -> Result<(), MigrationError> {
                self.alter_table(name, build).await
            }

            /// Drops table `name` using this dialect's drop statement.
            pub async fn drop(&mut self, name: &str) -> Result<(), MigrationError> {
                let table = TableBlueprint::new(name);
                self.execute_raw(&table.drop_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Creates (non-unique) index `name` on `table` over `columns`.
            pub async fn create_index(
                &mut self,
                name: &str,
                table: &str,
                columns: impl IntoSchemaColumns,
            ) -> Result<(), MigrationError> {
                let index = IndexBlueprint::new(name, table, columns);
                self.execute_raw(&index.create_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Creates unique index `name` on `table` over `columns`.
            pub async fn create_unique_index(
                &mut self,
                name: &str,
                table: &str,
                columns: impl IntoSchemaColumns,
            ) -> Result<(), MigrationError> {
                let index = IndexBlueprint::new_unique(name, table, columns);
                self.execute_raw(&index.create_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }

            /// Drops index `name`.
            pub async fn drop_index(&mut self, name: &str) -> Result<(), MigrationError> {
                let index = IndexBlueprint::named(name);
                self.execute_raw(&index.drop_sql(Self::SCHEMA_DIALECT))
                    .await?;
                Ok(())
            }
        }

        #[cfg(feature = $feature)]
        impl<'a> BlueprintExecutor for $context<'a> {
            fn dialect(&self) -> SchemaDialect {
                Self::SCHEMA_DIALECT
            }

            async fn execute_raw_blueprint(&mut self, sql: &str) -> Result<u64, MigrationError> {
                // Route through the inherent `execute_raw` so the column-type cache is reset.
                Self::execute_raw(self, sql).await
            }
        }

        /// Migration registration record for this backend, collected through `inventory`.
        #[cfg(feature = $feature)]
        #[derive(Clone, Copy)]
        pub struct $entry {
            /// Human-readable migration name.
            pub name: &'static str,
            /// Migration version number.
            pub version: u64,
            /// Forward migration step.
            pub up: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
            /// Reverse migration step.
            pub down: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
        }

        #[cfg(feature = $feature)]
        impl $entry {
            /// Assembles an entry; `const` so it can appear in static registrations.
            pub const fn new(
                name: &'static str,
                version: u64,
                up: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
                down: for<'a> fn(&'a mut $context<'a>) -> MigrationFuture<'a>,
            ) -> Self {
                Self {
                    name,
                    version,
                    up,
                    down,
                }
            }
        }

        #[cfg(feature = $feature)]
        inventory::collect!($entry);

        /// Migration written specifically for this backend's context.
        #[cfg(feature = $feature)]
        #[allow(async_fn_in_trait)]
        pub trait $entry_trait {
            /// Forward step of the migration.
            async fn up(ctx: &mut $context<'_>) -> Result<(), MigrationError>;
            /// Reverse step of the migration.
            async fn down(ctx: &mut $context<'_>) -> Result<(), MigrationError>;
        }
    };
}
470
// SQLite backend: `SqliteMigrationContext`, `SqliteMigrationEntry`, `SqliteMigration`.
define_backend!(
    feature = "sqlite",
    context = SqliteMigrationContext,
    entry = SqliteMigrationEntry,
    entry_trait = SqliteMigration,
    dialect = SchemaDialect::Sqlite
);

// PostgreSQL backend: `PostgresMigrationContext`, `PostgresMigrationEntry`,
// `PostgresMigration`.
define_backend!(
    feature = "postgres",
    context = PostgresMigrationContext,
    entry = PostgresMigrationEntry,
    entry_trait = PostgresMigration,
    dialect = SchemaDialect::Postgres
);

// MariaDB backend: `MariadbMigrationContext`, `MariadbMigrationEntry`, `MariadbMigration`.
define_backend!(
    feature = "mariadb",
    context = MariadbMigrationContext,
    entry = MariadbMigrationEntry,
    entry_trait = MariadbMigration,
    dialect = SchemaDialect::MariaDb
);
494
/// Backend-agnostic migration context: wraps the context of one enabled backend and
/// forwards each operation to it.
pub enum MigrationContext<'a> {
    /// SQLite backend (requires the `sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteMigrationContext<'a>),
    /// PostgreSQL backend (requires the `postgres` feature).
    #[cfg(feature = "postgres")]
    Postgres(PostgresMigrationContext<'a>),
    /// MariaDB backend (requires the `mariadb` feature).
    #[cfg(feature = "mariadb")]
    Mariadb(MariadbMigrationContext<'a>),
    /// Placeholder present only when no backend feature is enabled, so the type and its
    /// lifetime parameter still exist; its fallible operations return a
    /// backend-not-enabled error.
    #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
    Disabled(std::marker::PhantomData<&'a ()>),
}
505
506impl<'a> MigrationContext<'a> {
507 pub fn dialect(&self) -> SchemaDialect {
508 match self {
509 #[cfg(feature = "sqlite")]
510 Self::Sqlite(_) => SchemaDialect::Sqlite,
511 #[cfg(feature = "postgres")]
512 Self::Postgres(_) => SchemaDialect::Postgres,
513 #[cfg(feature = "mariadb")]
514 Self::Mariadb(_) => SchemaDialect::MariaDb,
515 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
516 Self::Disabled(_) => SchemaDialect::Sqlite,
517 }
518 }
519
520 pub async fn execute_raw(&mut self, sql: &str) -> Result<u64, MigrationError> {
521 match self {
522 #[cfg(feature = "sqlite")]
523 Self::Sqlite(ctx) => ctx.execute_raw(sql).await,
524 #[cfg(feature = "postgres")]
525 Self::Postgres(ctx) => ctx.execute_raw(sql).await,
526 #[cfg(feature = "mariadb")]
527 Self::Mariadb(ctx) => ctx.execute_raw(sql).await,
528 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
529 Self::Disabled(_) => Err(no_backend_error()),
530 }
531 }
532
533 pub async fn column_type(
534 &mut self,
535 table: &str,
536 column: &str,
537 ) -> Result<ColumnType, MigrationError> {
538 match self {
539 #[cfg(feature = "sqlite")]
540 Self::Sqlite(ctx) => ctx.column_type(table, column).await,
541 #[cfg(feature = "postgres")]
542 Self::Postgres(ctx) => ctx.column_type(table, column).await,
543 #[cfg(feature = "mariadb")]
544 Self::Mariadb(ctx) => ctx.column_type(table, column).await,
545 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
546 Self::Disabled(_) => Err(no_backend_error()),
547 }
548 }
549
550 pub async fn create(
551 &mut self,
552 name: &str,
553 build: impl FnOnce(&mut TableBlueprint),
554 ) -> Result<(), MigrationError> {
555 let mut build = Some(build);
556 match self {
557 #[cfg(feature = "sqlite")]
558 Self::Sqlite(ctx) => ctx.create(name, build.take().unwrap()).await,
559 #[cfg(feature = "postgres")]
560 Self::Postgres(ctx) => ctx.create(name, build.take().unwrap()).await,
561 #[cfg(feature = "mariadb")]
562 Self::Mariadb(ctx) => ctx.create(name, build.take().unwrap()).await,
563 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
564 Self::Disabled(_) => Err(no_backend_error()),
565 }
566 }
567
568 pub async fn alter_table(
569 &mut self,
570 name: &str,
571 build: impl FnOnce(&mut AlterTableBlueprint),
572 ) -> Result<(), MigrationError> {
573 let mut build = Some(build);
574 match self {
575 #[cfg(feature = "sqlite")]
576 Self::Sqlite(ctx) => ctx.alter_table(name, build.take().unwrap()).await,
577 #[cfg(feature = "postgres")]
578 Self::Postgres(ctx) => ctx.alter_table(name, build.take().unwrap()).await,
579 #[cfg(feature = "mariadb")]
580 Self::Mariadb(ctx) => ctx.alter_table(name, build.take().unwrap()).await,
581 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
582 Self::Disabled(_) => Err(no_backend_error()),
583 }
584 }
585
586 pub async fn table(
587 &mut self,
588 name: &str,
589 build: impl FnOnce(&mut AlterTableBlueprint),
590 ) -> Result<(), MigrationError> {
591 self.alter_table(name, build).await
592 }
593
594 pub async fn drop(&mut self, name: &str) -> Result<(), MigrationError> {
595 match self {
596 #[cfg(feature = "sqlite")]
597 Self::Sqlite(ctx) => ctx.drop(name).await,
598 #[cfg(feature = "postgres")]
599 Self::Postgres(ctx) => ctx.drop(name).await,
600 #[cfg(feature = "mariadb")]
601 Self::Mariadb(ctx) => ctx.drop(name).await,
602 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mariadb")))]
603 Self::Disabled(_) => Err(no_backend_error()),
604 }
605 }
606
607 pub async fn create_index(
608 &mut self,
609 name: &str,
610 table: &str,
611 columns: impl IntoSchemaColumns,
612 ) -> Result<(), MigrationError> {
613 let index = IndexBlueprint::new(name, table, columns);
614 self.execute_raw(&index.create_sql(self.dialect())).await?;
615 Ok(())
616 }
617
618 pub async fn create_unique_index(
619 &mut self,
620 name: &str,
621 table: &str,
622 columns: impl IntoSchemaColumns,
623 ) -> Result<(), MigrationError> {
624 let index = IndexBlueprint::new_unique(name, table, columns);
625 self.execute_raw(&index.create_sql(self.dialect())).await?;
626 Ok(())
627 }
628
629 pub async fn drop_index(&mut self, name: &str) -> Result<(), MigrationError> {
630 let index = IndexBlueprint::named(name);
631 self.execute_raw(&index.drop_sql(self.dialect())).await?;
632 Ok(())
633 }
634}
635
636impl<'a> BlueprintExecutor for MigrationContext<'a> {
637 fn dialect(&self) -> SchemaDialect {
638 Self::dialect(self)
639 }
640
641 async fn execute_raw_blueprint(&mut self, sql: &str) -> Result<u64, MigrationError> {
642 Self::execute_raw(self, sql).await
643 }
644}