refinery_core/traits/async.rs

1use crate::error::WrapMigrationError;
2use crate::traits::{
3    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
4    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
/// An asynchronous handle that can execute a batch of SQL statements.
///
/// `migrate_grouped` in this module hands its whole batch to a single `execute` call and
/// logs it as applying "in single transaction", so implementations are presumably expected
/// to run the slice atomically. NOTE(review): confirm against the implementors.
#[async_trait]
pub trait AsyncTransaction {
    /// Error produced by the underlying database driver.
    ///
    /// The `Send + Sync + 'static` bounds allow these errors to be wrapped through
    /// `WrapMigrationError::migration_err` (presumably boxed inside `crate::Error`; confirm).
    type Error: std::error::Error + Send + Sync + 'static;

    /// Executes the statements in `query`, presumably in slice order.
    ///
    /// Callers in this module place a migration's SQL before the statement that records it.
    /// NOTE(review): the meaning of the returned `usize` (statement count vs. affected rows)
    /// is not visible here, and no caller in this file reads it.
    async fn execute(&mut self, query: &[&str]) -> Result<usize, Self::Error>;
}
17
/// An [`AsyncTransaction`] that can also run a query and return its result as `T`.
///
/// In this module it is used as `AsyncQuery<Vec<Migration>>` (the supertrait of
/// `AsyncMigrate`) to read rows of the migrations table.
#[async_trait]
pub trait AsyncQuery<T>: AsyncTransaction {
    /// Runs the SQL statement `query` and returns its result converted to `T`.
    async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
}
22
23async fn migrate<T: AsyncTransaction>(
24    transaction: &mut T,
25    migrations: Vec<Migration>,
26    target: Target,
27    migration_table_name: &str,
28) -> Result<Report, Error> {
29    let mut applied_migrations = vec![];
30
31    for mut migration in migrations.into_iter() {
32        if let Target::Version(input_target) = target {
33            if input_target < migration.version() {
34                log::info!(
35                    "stopping at migration: {}, due to user option",
36                    input_target
37                );
38                break;
39            }
40        }
41
42        log::info!("applying migration: {}", migration);
43        migration.set_applied();
44        let update_query = insert_migration_query(&migration, migration_table_name);
45        transaction
46            .execute(&[
47                migration.sql().as_ref().expect("sql must be Some!"),
48                &update_query,
49            ])
50            .await
51            .migration_err(
52                &format!("error applying migration {}", migration),
53                Some(&applied_migrations),
54            )?;
55        applied_migrations.push(migration);
56    }
57    Ok(Report::new(applied_migrations))
58}
59
60async fn migrate_grouped<T: AsyncTransaction>(
61    transaction: &mut T,
62    migrations: Vec<Migration>,
63    target: Target,
64    migration_table_name: &str,
65) -> Result<Report, Error> {
66    let mut grouped_migrations = Vec::new();
67    let mut applied_migrations = Vec::new();
68
69    for mut migration in migrations.into_iter() {
70        if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
71            if input_target < migration.version() {
72                break;
73            }
74        }
75
76        migration.set_applied();
77        let query = insert_migration_query(&migration, migration_table_name);
78
79        let sql = migration.sql().expect("sql must be Some!").to_string();
80
81        // If Target is Fake, we only update schema migrations table
82        if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
83            applied_migrations.push(migration);
84            grouped_migrations.push(sql);
85        }
86        grouped_migrations.push(query);
87    }
88
89    match target {
90        Target::Fake | Target::FakeVersion(_) => {
91            log::info!("not going to apply any migration as fake flag is enabled");
92        }
93        Target::Latest | Target::Version(_) => {
94            log::info!(
95                "going to apply batch migrations in single transaction: {:#?}",
96                applied_migrations.iter().map(ToString::to_string)
97            );
98        }
99    };
100
101    if let Target::Version(input_target) = target {
102        log::info!(
103            "stopping at migration: {}, due to user option",
104            input_target
105        );
106    }
107
108    let refs: Vec<&str> = grouped_migrations.iter().map(AsRef::as_ref).collect();
109
110    transaction
111        .execute(refs.as_ref())
112        .await
113        .migration_err("error applying migrations", None)?;
114
115    Ok(Report::new(applied_migrations))
116}
117
118#[async_trait]
119pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
120where
121    Self: Sized,
122{
123    // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table
124    fn assert_migrations_table_query(migration_table_name: &str) -> String {
125        ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
126    }
127
128    fn get_last_applied_migration_query(migration_table_name: &str) -> String {
129        GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
130    }
131
132    fn get_applied_migrations_query(migration_table_name: &str) -> String {
133        GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
134    }
135
136    async fn get_last_applied_migration(
137        &mut self,
138        migration_table_name: &str,
139    ) -> Result<Option<Migration>, Error> {
140        let mut migrations = self
141            .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
142            .await
143            .migration_err("error getting last applied migration", None)?;
144
145        Ok(migrations.pop())
146    }
147
148    async fn get_applied_migrations(
149        &mut self,
150        migration_table_name: &str,
151    ) -> Result<Vec<Migration>, Error> {
152        let migrations = self
153            .query(Self::get_applied_migrations_query(migration_table_name).as_str())
154            .await
155            .migration_err("error getting applied migrations", None)?;
156
157        Ok(migrations)
158    }
159
160    async fn migrate(
161        &mut self,
162        migrations: &[Migration],
163        abort_divergent: bool,
164        abort_missing: bool,
165        grouped: bool,
166        target: Target,
167        migration_table_name: &str,
168    ) -> Result<Report, Error> {
169        self.execute(&[&Self::assert_migrations_table_query(migration_table_name)])
170            .await
171            .migration_err("error asserting migrations table", None)?;
172
173        let applied_migrations = self
174            .get_applied_migrations(migration_table_name)
175            .await
176            .migration_err("error getting current schema version", None)?;
177
178        let migrations = verify_migrations(
179            applied_migrations,
180            migrations.to_vec(),
181            abort_divergent,
182            abort_missing,
183        )?;
184
185        if migrations.is_empty() {
186            log::info!("no migrations to apply");
187        }
188
189        if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
190            migrate_grouped(self, migrations, target, migration_table_name).await
191        } else {
192            migrate(self, migrations, target, migration_table_name).await
193        }
194    }
195}