refinery_core/traits/async.rs

1use crate::error::WrapMigrationError;
2use crate::traits::{
3    insert_migration_query, verify_migrations, GET_APPLIED_MIGRATIONS_QUERY,
4    GET_LAST_APPLIED_MIGRATION_QUERY,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
11#[async_trait]
12pub trait AsyncTransaction {
13    type Error: std::error::Error + Send + Sync + 'static;
14
15    async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
16        &mut self,
17        queries: T,
18    ) -> Result<usize, Self::Error>;
19}
20
21#[async_trait]
22pub trait AsyncQuery<T>: AsyncTransaction {
23    async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
24}
25
26async fn migrate<T: AsyncTransaction>(
27    transaction: &mut T,
28    migrations: Vec<Migration>,
29    target: Target,
30    migration_table_name: &str,
31) -> Result<Report, Error> {
32    let mut applied_migrations = vec![];
33
34    for mut migration in migrations.into_iter() {
35        if let Target::Version(input_target) = target {
36            if input_target < migration.version() {
37                log::info!(
38                    "stopping at migration: {}, due to user option",
39                    input_target
40                );
41                break;
42            }
43        }
44
45        log::info!("applying migration: {}", migration);
46        migration.set_applied();
47        let update_query = insert_migration_query(&migration, migration_table_name);
48        transaction
49            .execute(
50                [
51                    migration.sql().as_ref().expect("sql must be Some!"),
52                    update_query.as_str(),
53                ]
54                .into_iter(),
55            )
56            .await
57            .migration_err(
58                &format!("error applying migration {migration}"),
59                Some(&applied_migrations),
60            )?;
61        applied_migrations.push(migration);
62    }
63    Ok(Report::new(applied_migrations))
64}
65
66async fn migrate_grouped<T: AsyncTransaction>(
67    transaction: &mut T,
68    migrations: Vec<Migration>,
69    target: Target,
70    migration_table_name: &str,
71) -> Result<Report, Error> {
72    let mut grouped_migrations = Vec::new();
73    let mut applied_migrations = Vec::new();
74
75    for mut migration in migrations.into_iter() {
76        if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
77            if input_target < migration.version() {
78                break;
79            }
80        }
81
82        migration.set_applied();
83        let query = insert_migration_query(&migration, migration_table_name);
84
85        let sql = migration.sql().expect("sql must be Some!").to_string();
86
87        // If Target is Fake, we only update schema migrations table
88        if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
89            applied_migrations.push(migration);
90            grouped_migrations.push(sql);
91        }
92        grouped_migrations.push(query);
93    }
94
95    match target {
96        Target::Fake | Target::FakeVersion(_) => {
97            log::info!("not going to apply any migration as fake flag is enabled");
98        }
99        Target::Latest | Target::Version(_) => {
100            let migrations_display = applied_migrations
101                .iter()
102                .map(ToString::to_string)
103                .collect::<Vec<String>>()
104                .join("\n");
105            log::info!(
106                "going to apply batch migrations in single transaction:\n{migrations_display}"
107            );
108        }
109    };
110
111    if let Target::Version(input_target) = target {
112        log::info!(
113            "stopping at migration: {}, due to user option",
114            input_target
115        );
116    }
117
118    transaction
119        .execute(grouped_migrations.iter().map(AsRef::as_ref))
120        .await
121        .migration_err("error applying migrations", None)?;
122
123    Ok(Report::new(applied_migrations))
124}
125
126#[async_trait]
127pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
128where
129    Self: Sized,
130{
131    // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table
132    fn assert_migrations_table_query(migration_table_name: &str) -> String {
133        super::assert_migrations_table_query(migration_table_name)
134    }
135
136    fn get_last_applied_migration_query(migration_table_name: &str) -> String {
137        GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
138    }
139
140    fn get_applied_migrations_query(migration_table_name: &str) -> String {
141        GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
142    }
143
144    async fn get_last_applied_migration(
145        &mut self,
146        migration_table_name: &str,
147    ) -> Result<Option<Migration>, Error> {
148        let mut migrations = self
149            .query(Self::get_last_applied_migration_query(migration_table_name).as_ref())
150            .await
151            .migration_err("error getting last applied migration", None)?;
152
153        Ok(migrations.pop())
154    }
155
156    async fn get_applied_migrations(
157        &mut self,
158        migration_table_name: &str,
159    ) -> Result<Vec<Migration>, Error> {
160        let migrations = self
161            .query(Self::get_applied_migrations_query(migration_table_name).as_ref())
162            .await
163            .migration_err("error getting applied migrations", None)?;
164
165        Ok(migrations)
166    }
167
168    async fn migrate(
169        &mut self,
170        migrations: &[Migration],
171        abort_divergent: bool,
172        abort_missing: bool,
173        grouped: bool,
174        target: Target,
175        migration_table_name: &str,
176    ) -> Result<Report, Error> {
177        self.execute(
178            [Self::assert_migrations_table_query(migration_table_name).as_ref()].into_iter(),
179        )
180        .await
181        .migration_err("error asserting migrations table", None)?;
182
183        let applied_migrations = self
184            .get_applied_migrations(migration_table_name)
185            .await
186            .migration_err("error getting current schema version", None)?;
187
188        let migrations = verify_migrations(
189            applied_migrations,
190            migrations.to_vec(),
191            abort_divergent,
192            abort_missing,
193        )?;
194
195        if migrations.is_empty() {
196            log::info!("no migrations to apply");
197        }
198
199        if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
200            migrate_grouped(self, migrations, target, migration_table_name).await
201        } else {
202            migrate(self, migrations, target, migration_table_name).await
203        }
204    }
205}