refinery_core/traits/
async.rs1use crate::error::WrapMigrationError;
2use crate::traits::{
3 insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
4 GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
/// Asynchronous executor for a batch of raw SQL statements.
///
/// The migration runners in this file hand it the statements to run: a migration's SQL
/// followed by the history-table insert for that migration, or a whole grouped batch.
/// NOTE(review): the grouped path logs that the batch is applied "in single transaction", so
/// implementations are presumably expected to run all statements of one `execute` call
/// atomically — confirm per backend.
#[async_trait]
pub trait AsyncTransaction {
    /// Backend-specific error produced by [`AsyncTransaction::execute`].
    ///
    /// The callers in this file convert it into the crate error with `migration_err`
    /// (`crate::error::WrapMigrationError`), which these bounds make possible.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Executes every statement yielded by `queries`.
    ///
    /// Presumably in iteration order, since callers place a migration's SQL before its
    /// history insert — TODO confirm. The meaning of the returned `usize` (e.g. statements
    /// run or rows affected) is implementation-defined; callers in this file ignore it.
    async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
        &mut self,
        queries: T,
    ) -> Result<usize, Self::Error>;
}
20
/// An [`AsyncTransaction`] that can also run a single query and return its result as `T`.
///
/// In this file it is used with `T = Vec<Migration>` to read rows back from the migration
/// history table.
#[async_trait]
pub trait AsyncQuery<T>: AsyncTransaction {
    /// Runs `query` and returns its result as `T`.
    ///
    /// How result rows are turned into `T` is up to the implementation. Failures use the
    /// backend's `Self::Error` from [`AsyncTransaction`].
    async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
}
25
26async fn migrate<T: AsyncTransaction>(
27 transaction: &mut T,
28 migrations: Vec<Migration>,
29 target: Target,
30 migration_table_name: &str,
31) -> Result<Report, Error> {
32 let mut applied_migrations = vec![];
33
34 for mut migration in migrations.into_iter() {
35 if let Target::Version(input_target) = target {
36 if input_target < migration.version() {
37 log::info!(
38 "stopping at migration: {}, due to user option",
39 input_target
40 );
41 break;
42 }
43 }
44
45 log::info!("applying migration: {}", migration);
46 migration.set_applied();
47 let update_query = insert_migration_query(&migration, migration_table_name);
48 transaction
49 .execute(
50 [
51 migration.sql().as_ref().expect("sql must be Some!"),
52 update_query.as_str(),
53 ]
54 .into_iter(),
55 )
56 .await
57 .migration_err(
58 &format!("error applying migration {}", migration),
59 Some(&applied_migrations),
60 )?;
61 applied_migrations.push(migration);
62 }
63 Ok(Report::new(applied_migrations))
64}
65
66async fn migrate_grouped<T: AsyncTransaction>(
67 transaction: &mut T,
68 migrations: Vec<Migration>,
69 target: Target,
70 migration_table_name: &str,
71) -> Result<Report, Error> {
72 let mut grouped_migrations = Vec::new();
73 let mut applied_migrations = Vec::new();
74
75 for mut migration in migrations.into_iter() {
76 if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
77 if input_target < migration.version() {
78 break;
79 }
80 }
81
82 migration.set_applied();
83 let query = insert_migration_query(&migration, migration_table_name);
84
85 let sql = migration.sql().expect("sql must be Some!").to_string();
86
87 if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
89 applied_migrations.push(migration);
90 grouped_migrations.push(sql);
91 }
92 grouped_migrations.push(query);
93 }
94
95 match target {
96 Target::Fake | Target::FakeVersion(_) => {
97 log::info!("not going to apply any migration as fake flag is enabled");
98 }
99 Target::Latest | Target::Version(_) => {
100 log::info!(
101 "going to apply batch migrations in single transaction: {:#?}",
102 applied_migrations.iter().map(ToString::to_string)
103 );
104 }
105 };
106
107 if let Target::Version(input_target) = target {
108 log::info!(
109 "stopping at migration: {}, due to user option",
110 input_target
111 );
112 }
113
114 let refs = grouped_migrations.iter().map(AsRef::as_ref);
115
116 transaction
117 .execute(refs)
118 .await
119 .migration_err("error applying migrations", None)?;
120
121 Ok(Report::new(applied_migrations))
122}
123
124#[async_trait]
125pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
126where
127 Self: Sized,
128{
129 fn assert_migrations_table_query(migration_table_name: &str) -> String {
131 ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
132 }
133
134 fn get_last_applied_migration_query(migration_table_name: &str) -> String {
135 GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
136 }
137
138 fn get_applied_migrations_query(migration_table_name: &str) -> String {
139 GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
140 }
141
142 async fn get_last_applied_migration(
143 &mut self,
144 migration_table_name: &str,
145 ) -> Result<Option<Migration>, Error> {
146 let mut migrations = self
147 .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
148 .await
149 .migration_err("error getting last applied migration", None)?;
150
151 Ok(migrations.pop())
152 }
153
154 async fn get_applied_migrations(
155 &mut self,
156 migration_table_name: &str,
157 ) -> Result<Vec<Migration>, Error> {
158 let migrations = self
159 .query(Self::get_applied_migrations_query(migration_table_name).as_str())
160 .await
161 .migration_err("error getting applied migrations", None)?;
162
163 Ok(migrations)
164 }
165
166 async fn migrate(
167 &mut self,
168 migrations: &[Migration],
169 abort_divergent: bool,
170 abort_missing: bool,
171 grouped: bool,
172 target: Target,
173 migration_table_name: &str,
174 ) -> Result<Report, Error> {
175 self.execute(
176 [Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter(),
177 )
178 .await
179 .migration_err("error asserting migrations table", None)?;
180
181 let applied_migrations = self
182 .get_applied_migrations(migration_table_name)
183 .await
184 .migration_err("error getting current schema version", None)?;
185
186 let migrations = verify_migrations(
187 applied_migrations,
188 migrations.to_vec(),
189 abort_divergent,
190 abort_missing,
191 )?;
192
193 if migrations.is_empty() {
194 log::info!("no migrations to apply");
195 }
196
197 if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
198 migrate_grouped(self, migrations, target, migration_table_name).await
199 } else {
200 migrate(self, migrations, target, migration_table_name).await
201 }
202 }
203}