refinery_core/traits/
async.rs

1use crate::error::WrapMigrationError;
2use crate::traits::{
3    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
4    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
11#[async_trait]
12pub trait AsyncTransaction {
13    type Error: std::error::Error + Send + Sync + 'static;
14
15    async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
16        &mut self,
17        queries: T,
18    ) -> Result<usize, Self::Error>;
19}
20
21#[async_trait]
22pub trait AsyncQuery<T>: AsyncTransaction {
23    async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
24}
25
26async fn migrate<T: AsyncTransaction>(
27    transaction: &mut T,
28    migrations: Vec<Migration>,
29    target: Target,
30    migration_table_name: &str,
31) -> Result<Report, Error> {
32    let mut applied_migrations = vec![];
33
34    for mut migration in migrations.into_iter() {
35        if let Target::Version(input_target) = target {
36            if input_target < migration.version() {
37                log::info!(
38                    "stopping at migration: {}, due to user option",
39                    input_target
40                );
41                break;
42            }
43        }
44
45        log::info!("applying migration: {}", migration);
46        migration.set_applied();
47        let update_query = insert_migration_query(&migration, migration_table_name);
48        transaction
49            .execute(
50                [
51                    migration.sql().as_ref().expect("sql must be Some!"),
52                    update_query.as_str(),
53                ]
54                .into_iter(),
55            )
56            .await
57            .migration_err(
58                &format!("error applying migration {}", migration),
59                Some(&applied_migrations),
60            )?;
61        applied_migrations.push(migration);
62    }
63    Ok(Report::new(applied_migrations))
64}
65
66async fn migrate_grouped<T: AsyncTransaction>(
67    transaction: &mut T,
68    migrations: Vec<Migration>,
69    target: Target,
70    migration_table_name: &str,
71) -> Result<Report, Error> {
72    let mut grouped_migrations = Vec::new();
73    let mut applied_migrations = Vec::new();
74
75    for mut migration in migrations.into_iter() {
76        if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
77            if input_target < migration.version() {
78                break;
79            }
80        }
81
82        migration.set_applied();
83        let query = insert_migration_query(&migration, migration_table_name);
84
85        let sql = migration.sql().expect("sql must be Some!").to_string();
86
87        // If Target is Fake, we only update schema migrations table
88        if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
89            applied_migrations.push(migration);
90            grouped_migrations.push(sql);
91        }
92        grouped_migrations.push(query);
93    }
94
95    match target {
96        Target::Fake | Target::FakeVersion(_) => {
97            log::info!("not going to apply any migration as fake flag is enabled");
98        }
99        Target::Latest | Target::Version(_) => {
100            log::info!(
101                "going to apply batch migrations in single transaction: {:#?}",
102                applied_migrations.iter().map(ToString::to_string)
103            );
104        }
105    };
106
107    if let Target::Version(input_target) = target {
108        log::info!(
109            "stopping at migration: {}, due to user option",
110            input_target
111        );
112    }
113
114    let refs = grouped_migrations.iter().map(AsRef::as_ref);
115
116    transaction
117        .execute(refs)
118        .await
119        .migration_err("error applying migrations", None)?;
120
121    Ok(Report::new(applied_migrations))
122}
123
124#[async_trait]
125pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
126where
127    Self: Sized,
128{
129    // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table
130    fn assert_migrations_table_query(migration_table_name: &str) -> String {
131        ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
132    }
133
134    fn get_last_applied_migration_query(migration_table_name: &str) -> String {
135        GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
136    }
137
138    fn get_applied_migrations_query(migration_table_name: &str) -> String {
139        GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
140    }
141
142    async fn get_last_applied_migration(
143        &mut self,
144        migration_table_name: &str,
145    ) -> Result<Option<Migration>, Error> {
146        let mut migrations = self
147            .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
148            .await
149            .migration_err("error getting last applied migration", None)?;
150
151        Ok(migrations.pop())
152    }
153
154    async fn get_applied_migrations(
155        &mut self,
156        migration_table_name: &str,
157    ) -> Result<Vec<Migration>, Error> {
158        let migrations = self
159            .query(Self::get_applied_migrations_query(migration_table_name).as_str())
160            .await
161            .migration_err("error getting applied migrations", None)?;
162
163        Ok(migrations)
164    }
165
166    async fn migrate(
167        &mut self,
168        migrations: &[Migration],
169        abort_divergent: bool,
170        abort_missing: bool,
171        grouped: bool,
172        target: Target,
173        migration_table_name: &str,
174    ) -> Result<Report, Error> {
175        self.execute(
176            [Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter(),
177        )
178        .await
179        .migration_err("error asserting migrations table", None)?;
180
181        let applied_migrations = self
182            .get_applied_migrations(migration_table_name)
183            .await
184            .migration_err("error getting current schema version", None)?;
185
186        let migrations = verify_migrations(
187            applied_migrations,
188            migrations.to_vec(),
189            abort_divergent,
190            abort_missing,
191        )?;
192
193        if migrations.is_empty() {
194            log::info!("no migrations to apply");
195        }
196
197        if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
198            migrate_grouped(self, migrations, target, migration_table_name).await
199        } else {
200            migrate(self, migrations, target, migration_table_name).await
201        }
202    }
203}