Skip to main content

trailbase_refinery/traits/
async.rs

1use crate::error::WrapMigrationError;
2use crate::traits::{
3  ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
4  insert_migration_query, verify_migrations,
5};
6use crate::{Error, Migration, Report, Target};
7
8use async_trait::async_trait;
9use std::string::ToString;
10
11#[async_trait]
12pub trait AsyncTransaction {
13  type Error: std::error::Error + Send + Sync + 'static;
14
15  async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
16    &mut self,
17    queries: T,
18  ) -> Result<usize, Self::Error>;
19}
20
21#[async_trait]
22pub trait AsyncQuery<T>: AsyncTransaction {
23  async fn query(&mut self, query: &str) -> Result<T, Self::Error>;
24}
25
26async fn migrate<T: AsyncTransaction>(
27  transaction: &mut T,
28  migrations: Vec<Migration>,
29  target: Target,
30  migration_table_name: &str,
31) -> Result<Report, Error> {
32  let mut applied_migrations = vec![];
33
34  for mut migration in migrations.into_iter() {
35    if let Target::Version(input_target) = target {
36      if input_target < migration.version() {
37        log::info!(
38          "stopping at migration: {}, due to user option",
39          input_target
40        );
41        break;
42      }
43    }
44
45    log::info!("applying migration: {}", migration);
46    migration.set_applied();
47    let update_query = insert_migration_query(&migration, migration_table_name);
48    transaction
49      .execute(
50        [
51          migration.sql().as_ref().expect("sql must be Some!"),
52          update_query.as_str(),
53        ]
54        .into_iter(),
55      )
56      .await
57      .migration_err(
58        &format!("error applying migration {}", migration),
59        Some(&applied_migrations),
60      )?;
61    applied_migrations.push(migration);
62  }
63  Ok(Report::new(applied_migrations))
64}
65
66async fn migrate_grouped<T: AsyncTransaction>(
67  transaction: &mut T,
68  migrations: Vec<Migration>,
69  target: Target,
70  migration_table_name: &str,
71) -> Result<Report, Error> {
72  let mut grouped_migrations = Vec::new();
73  let mut applied_migrations = Vec::new();
74
75  for mut migration in migrations.into_iter() {
76    if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
77      if input_target < migration.version() {
78        break;
79      }
80    }
81
82    migration.set_applied();
83    let query = insert_migration_query(&migration, migration_table_name);
84
85    let sql = migration.sql().expect("sql must be Some!").to_string();
86
87    // If Target is Fake, we only update schema migrations table
88    if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
89      applied_migrations.push(migration);
90      grouped_migrations.push(sql);
91    }
92    grouped_migrations.push(query);
93  }
94
95  match target {
96    Target::Fake | Target::FakeVersion(_) => {
97      log::info!("not going to apply any migration as fake flag is enabled");
98    }
99    Target::Latest | Target::Version(_) => {
100      log::info!(
101        "going to apply batch migrations in single transaction: {:#?}",
102        applied_migrations.iter().map(ToString::to_string)
103      );
104    }
105  };
106
107  if let Target::Version(input_target) = target {
108    log::info!(
109      "stopping at migration: {}, due to user option",
110      input_target
111    );
112  }
113
114  let refs = grouped_migrations.iter().map(AsRef::as_ref);
115
116  transaction
117    .execute(refs)
118    .await
119    .migration_err("error applying migrations", None)?;
120
121  Ok(Report::new(applied_migrations))
122}
123
124#[async_trait]
125pub trait AsyncMigrate: AsyncQuery<Vec<Migration>>
126where
127  Self: Sized,
128{
129  // Needed cause some database vendors like Mssql have a non sql standard way of checking the
130  // migrations table
131  fn assert_migrations_table_query(migration_table_name: &str) -> String {
132    ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
133  }
134
135  async fn get_last_applied_migration(
136    &mut self,
137    migration_table_name: &str,
138  ) -> Result<Option<Migration>, Error> {
139    let mut migrations = self
140      .query(
141        &GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name),
142      )
143      .await
144      .migration_err("error getting last applied migration", None)?;
145
146    Ok(migrations.pop())
147  }
148
149  async fn get_applied_migrations(
150    &mut self,
151    migration_table_name: &str,
152  ) -> Result<Vec<Migration>, Error> {
153    let migrations = self
154      .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
155      .await
156      .migration_err("error getting applied migrations", None)?;
157
158    Ok(migrations)
159  }
160
161  async fn migrate(
162    &mut self,
163    migrations: &[Migration],
164    abort_divergent: bool,
165    abort_missing: bool,
166    grouped: bool,
167    target: Target,
168    migration_table_name: &str,
169  ) -> Result<Report, Error> {
170    self
171      .execute([Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter())
172      .await
173      .migration_err("error asserting migrations table", None)?;
174
175    let applied_migrations = self
176      .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
177      .await
178      .migration_err("error getting current schema version", None)?;
179
180    let migrations = verify_migrations(
181      applied_migrations,
182      migrations.to_vec(),
183      abort_divergent,
184      abort_missing,
185    )?;
186
187    if migrations.is_empty() {
188      log::info!("no migrations to apply");
189    }
190
191    if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
192      migrate_grouped(self, migrations, target, migration_table_name).await
193    } else {
194      migrate(self, migrations, target, migration_table_name).await
195    }
196  }
197}