refinery_core/traits/
async.rs1use crate::error::WrapMigrationError;
2use crate::traits::{
3 insert_migration_query, verify_migrations, GET_APPLIED_MIGRATIONS_QUERY,
4 GET_LAST_APPLIED_MIGRATION_QUERY,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
11#[async_trait]
12pub trait AsyncTransaction {
13 type Error: std::error::Error + Send + Sync + 'static;
14
15 async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
16 &mut self,
17 queries: T,
18 ) -> Result<usize, Self::Error>;
19}
20
21#[async_trait]
22pub trait AsyncQuery<T>: AsyncTransaction {
23 async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
24}
25
26async fn migrate<T: AsyncTransaction>(
27 transaction: &mut T,
28 migrations: Vec<Migration>,
29 target: Target,
30 migration_table_name: &str,
31) -> Result<Report, Error> {
32 let mut applied_migrations = vec![];
33
34 for mut migration in migrations.into_iter() {
35 if let Target::Version(input_target) = target {
36 if input_target < migration.version() {
37 log::info!(
38 "stopping at migration: {}, due to user option",
39 input_target
40 );
41 break;
42 }
43 }
44
45 log::info!("applying migration: {}", migration);
46 migration.set_applied();
47 let update_query = insert_migration_query(&migration, migration_table_name);
48 transaction
49 .execute(
50 [
51 migration.sql().as_ref().expect("sql must be Some!"),
52 update_query.as_str(),
53 ]
54 .into_iter(),
55 )
56 .await
57 .migration_err(
58 &format!("error applying migration {migration}"),
59 Some(&applied_migrations),
60 )?;
61 applied_migrations.push(migration);
62 }
63 Ok(Report::new(applied_migrations))
64}
65
66async fn migrate_grouped<T: AsyncTransaction>(
67 transaction: &mut T,
68 migrations: Vec<Migration>,
69 target: Target,
70 migration_table_name: &str,
71) -> Result<Report, Error> {
72 let mut grouped_migrations = Vec::new();
73 let mut applied_migrations = Vec::new();
74
75 for mut migration in migrations.into_iter() {
76 if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
77 if input_target < migration.version() {
78 break;
79 }
80 }
81
82 migration.set_applied();
83 let query = insert_migration_query(&migration, migration_table_name);
84
85 let sql = migration.sql().expect("sql must be Some!").to_string();
86
87 if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
89 applied_migrations.push(migration);
90 grouped_migrations.push(sql);
91 }
92 grouped_migrations.push(query);
93 }
94
95 match target {
96 Target::Fake | Target::FakeVersion(_) => {
97 log::info!("not going to apply any migration as fake flag is enabled");
98 }
99 Target::Latest | Target::Version(_) => {
100 let migrations_display = applied_migrations
101 .iter()
102 .map(ToString::to_string)
103 .collect::<Vec<String>>()
104 .join("\n");
105 log::info!(
106 "going to apply batch migrations in single transaction:\n{migrations_display}"
107 );
108 }
109 };
110
111 if let Target::Version(input_target) = target {
112 log::info!(
113 "stopping at migration: {}, due to user option",
114 input_target
115 );
116 }
117
118 transaction
119 .execute(grouped_migrations.iter().map(AsRef::as_ref))
120 .await
121 .migration_err("error applying migrations", None)?;
122
123 Ok(Report::new(applied_migrations))
124}
125
126#[async_trait]
127pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
128where
129 Self: Sized,
130{
131 fn assert_migrations_table_query(migration_table_name: &str) -> String {
133 super::assert_migrations_table_query(migration_table_name)
134 }
135
136 fn get_last_applied_migration_query(migration_table_name: &str) -> String {
137 GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
138 }
139
140 fn get_applied_migrations_query(migration_table_name: &str) -> String {
141 GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
142 }
143
144 async fn get_last_applied_migration(
145 &mut self,
146 migration_table_name: &str,
147 ) -> Result<Option<Migration>, Error> {
148 let mut migrations = self
149 .query(Self::get_last_applied_migration_query(migration_table_name).as_ref())
150 .await
151 .migration_err("error getting last applied migration", None)?;
152
153 Ok(migrations.pop())
154 }
155
156 async fn get_applied_migrations(
157 &mut self,
158 migration_table_name: &str,
159 ) -> Result<Vec<Migration>, Error> {
160 let migrations = self
161 .query(Self::get_applied_migrations_query(migration_table_name).as_ref())
162 .await
163 .migration_err("error getting applied migrations", None)?;
164
165 Ok(migrations)
166 }
167
168 async fn migrate(
169 &mut self,
170 migrations: &[Migration],
171 abort_divergent: bool,
172 abort_missing: bool,
173 grouped: bool,
174 target: Target,
175 migration_table_name: &str,
176 ) -> Result<Report, Error> {
177 self.execute(
178 [Self::assert_migrations_table_query(migration_table_name).as_ref()].into_iter(),
179 )
180 .await
181 .migration_err("error asserting migrations table", None)?;
182
183 let applied_migrations = self
184 .get_applied_migrations(migration_table_name)
185 .await
186 .migration_err("error getting current schema version", None)?;
187
188 let migrations = verify_migrations(
189 applied_migrations,
190 migrations.to_vec(),
191 abort_divergent,
192 abort_missing,
193 )?;
194
195 if migrations.is_empty() {
196 log::info!("no migrations to apply");
197 }
198
199 if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
200 migrate_grouped(self, migrations, target, migration_table_name).await
201 } else {
202 migrate(self, migrations, target, migration_table_name).await
203 }
204 }
205}