// refinery_core/traits/async.rs
use crate::error::WrapMigrationError;
use crate::traits::{
    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
};
use crate::{Error, Migration, Report, Target};

use async_trait::async_trait;
use std::string::ToString;
10
11#[async_trait]
12pub trait AsyncTransaction {
13 type Error: std::error::Error + Send + Sync + 'static;
14
15 async fn execute(&mut self, query: &[&str]) -> Result<usize, Self::Error>;
16}
17
18#[async_trait]
19pub trait AsyncQuery<T>: AsyncTransaction {
20 async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
21}
22
23async fn migrate<T: AsyncTransaction>(
24 transaction: &mut T,
25 migrations: Vec<Migration>,
26 target: Target,
27 migration_table_name: &str,
28) -> Result<Report, Error> {
29 let mut applied_migrations = vec![];
30
31 for mut migration in migrations.into_iter() {
32 if let Target::Version(input_target) = target {
33 if input_target < migration.version() {
34 log::info!(
35 "stopping at migration: {}, due to user option",
36 input_target
37 );
38 break;
39 }
40 }
41
42 log::info!("applying migration: {}", migration);
43 migration.set_applied();
44 let update_query = insert_migration_query(&migration, migration_table_name);
45 transaction
46 .execute(&[
47 migration.sql().as_ref().expect("sql must be Some!"),
48 &update_query,
49 ])
50 .await
51 .migration_err(
52 &format!("error applying migration {}", migration),
53 Some(&applied_migrations),
54 )?;
55 applied_migrations.push(migration);
56 }
57 Ok(Report::new(applied_migrations))
58}
59
60async fn migrate_grouped<T: AsyncTransaction>(
61 transaction: &mut T,
62 migrations: Vec<Migration>,
63 target: Target,
64 migration_table_name: &str,
65) -> Result<Report, Error> {
66 let mut grouped_migrations = Vec::new();
67 let mut applied_migrations = Vec::new();
68
69 for mut migration in migrations.into_iter() {
70 if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
71 if input_target < migration.version() {
72 break;
73 }
74 }
75
76 migration.set_applied();
77 let query = insert_migration_query(&migration, migration_table_name);
78
79 let sql = migration.sql().expect("sql must be Some!").to_string();
80
81 if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
83 applied_migrations.push(migration);
84 grouped_migrations.push(sql);
85 }
86 grouped_migrations.push(query);
87 }
88
89 match target {
90 Target::Fake | Target::FakeVersion(_) => {
91 log::info!("not going to apply any migration as fake flag is enabled");
92 }
93 Target::Latest | Target::Version(_) => {
94 log::info!(
95 "going to apply batch migrations in single transaction: {:#?}",
96 applied_migrations.iter().map(ToString::to_string)
97 );
98 }
99 };
100
101 if let Target::Version(input_target) = target {
102 log::info!(
103 "stopping at migration: {}, due to user option",
104 input_target
105 );
106 }
107
108 let refs: Vec<&str> = grouped_migrations.iter().map(AsRef::as_ref).collect();
109
110 transaction
111 .execute(refs.as_ref())
112 .await
113 .migration_err("error applying migrations", None)?;
114
115 Ok(Report::new(applied_migrations))
116}
117
118#[async_trait]
119pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
120where
121 Self: Sized,
122{
123 fn assert_migrations_table_query(migration_table_name: &str) -> String {
125 ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
126 }
127
128 fn get_last_applied_migration_query(migration_table_name: &str) -> String {
129 GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
130 }
131
132 fn get_applied_migrations_query(migration_table_name: &str) -> String {
133 GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
134 }
135
136 async fn get_last_applied_migration(
137 &mut self,
138 migration_table_name: &str,
139 ) -> Result<Option<Migration>, Error> {
140 let mut migrations = self
141 .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
142 .await
143 .migration_err("error getting last applied migration", None)?;
144
145 Ok(migrations.pop())
146 }
147
148 async fn get_applied_migrations(
149 &mut self,
150 migration_table_name: &str,
151 ) -> Result<Vec<Migration>, Error> {
152 let migrations = self
153 .query(Self::get_applied_migrations_query(migration_table_name).as_str())
154 .await
155 .migration_err("error getting applied migrations", None)?;
156
157 Ok(migrations)
158 }
159
160 async fn migrate(
161 &mut self,
162 migrations: &[Migration],
163 abort_divergent: bool,
164 abort_missing: bool,
165 grouped: bool,
166 target: Target,
167 migration_table_name: &str,
168 ) -> Result<Report, Error> {
169 self.execute(&[&Self::assert_migrations_table_query(migration_table_name)])
170 .await
171 .migration_err("error asserting migrations table", None)?;
172
173 let applied_migrations = self
174 .get_applied_migrations(migration_table_name)
175 .await
176 .migration_err("error getting current schema version", None)?;
177
178 let migrations = verify_migrations(
179 applied_migrations,
180 migrations.to_vec(),
181 abort_divergent,
182 abort_missing,
183 )?;
184
185 if migrations.is_empty() {
186 log::info!("no migrations to apply");
187 }
188
189 if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
190 migrate_grouped(self, migrations, target, migration_table_name).await
191 } else {
192 migrate(self, migrations, target, migration_table_name).await
193 }
194 }
195}