Skip to main content

trailbase_refinery/traits/
sync.rs

1use std::ops::Deref;
2
3use crate::error::WrapMigrationError;
4use crate::traits::{
5  ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
6  insert_migration_query, verify_migrations,
7};
8use crate::{Error, Migration, Report, Target};
9
/// A handle through which batches of SQL statements can be executed
/// synchronously.
pub trait Transaction {
  /// Error type reported when executing statements fails.
  type Error: std::error::Error + Send + Sync + 'static;

  /// Executes every statement yielded by `queries`.
  ///
  /// NOTE(review): the meaning of the returned `usize` is defined by the
  /// implementor (presumably a count of executed statements or affected rows
  /// -- confirm against the concrete drivers).
  fn execute<'a, T>(&mut self, queries: T) -> Result<usize, Self::Error>
  where
    T: Iterator<Item = &'a str>;
}
15
16pub trait Query<T>: Transaction {
17  fn query(&mut self, query: &str) -> Result<T, Self::Error>;
18}
19
20pub fn migrate<T: Transaction>(
21  transaction: &mut T,
22  migrations: Vec<Migration>,
23  target: Target,
24  migration_table_name: &str,
25  grouped: bool,
26) -> Result<Report, Error> {
27  let mut migration_batch = Vec::new();
28  let mut applied_migrations = Vec::new();
29
30  for mut migration in migrations.into_iter() {
31    if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
32      if input_target < migration.version() {
33        log::info!(
34          "stopping at migration: {}, due to user option",
35          input_target
36        );
37        break;
38      }
39    }
40
41    log::info!("applying migration: {}", migration);
42    migration.set_applied();
43    let insert_migration = insert_migration_query(&migration, migration_table_name);
44    let migration_sql = migration.sql().expect("sql must be Some!").to_string();
45
46    // If Target is Fake, we only update schema migrations table
47    if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
48      applied_migrations.push(migration);
49      migration_batch.push(migration_sql);
50    }
51    migration_batch.push(insert_migration);
52  }
53
54  match (target, grouped) {
55    (Target::Fake | Target::FakeVersion(_), _) => {
56      log::info!("not going to apply any migration as fake flag is enabled");
57    }
58    (Target::Latest | Target::Version(_), true) => {
59      log::info!(
60        "going to apply batch migrations in single transaction: {:#?}",
61        applied_migrations.iter().map(ToString::to_string)
62      );
63    }
64    (Target::Latest | Target::Version(_), false) => {
65      log::info!(
66        "preparing to apply {} migrations: {:#?}",
67        applied_migrations.len(),
68        applied_migrations.iter().map(ToString::to_string)
69      );
70    }
71  };
72
73  if grouped {
74    transaction
75      .execute(migration_batch.iter().map(Deref::deref))
76      .migration_err("error applying migrations", None)?;
77  } else {
78    for (i, update) in migration_batch.into_iter().enumerate() {
79      transaction
80        .execute([update.as_str()].into_iter())
81        .migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?;
82    }
83  }
84
85  Ok(Report::new(applied_migrations))
86}
87
88pub trait Migrate: Query<Vec<Migration>>
89where
90  Self: Sized,
91{
92  fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
93    // Needed cause some database vendors like Mssql have a non sql standard way of checking the
94    // migrations table, thou on this case it's just to be consistent with the async trait
95    // `AsyncMigrate`
96    self
97      .execute(
98        [ASSERT_MIGRATIONS_TABLE_QUERY
99          .replace("%MIGRATION_TABLE_NAME%", migration_table_name)
100          .as_str()]
101        .into_iter(),
102      )
103      .migration_err("error asserting migrations table", None)
104  }
105
106  fn get_last_applied_migration(
107    &mut self,
108    migration_table_name: &str,
109  ) -> Result<Option<Migration>, Error> {
110    let mut migrations = self
111      .query(
112        &GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name),
113      )
114      .migration_err("error getting last applied migration", None)?;
115
116    Ok(migrations.pop())
117  }
118
119  fn get_applied_migrations(
120    &mut self,
121    migration_table_name: &str,
122  ) -> Result<Vec<Migration>, Error> {
123    let migrations = self
124      .query(&GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name))
125      .migration_err("error getting applied migrations", None)?;
126
127    Ok(migrations)
128  }
129
130  fn get_unapplied_migrations(
131    &mut self,
132    migrations: &[Migration],
133    abort_divergent: bool,
134    abort_missing: bool,
135    migration_table_name: &str,
136  ) -> Result<Vec<Migration>, Error> {
137    self.assert_migrations_table(migration_table_name)?;
138
139    let applied_migrations = self.get_applied_migrations(migration_table_name)?;
140
141    let migrations = verify_migrations(
142      applied_migrations,
143      migrations.to_vec(),
144      abort_divergent,
145      abort_missing,
146    )?;
147
148    if migrations.is_empty() {
149      log::info!("no migrations to apply");
150    }
151
152    Ok(migrations)
153  }
154
155  fn migrate(
156    &mut self,
157    migrations: &[Migration],
158    abort_divergent: bool,
159    abort_missing: bool,
160    grouped: bool,
161    target: Target,
162    migration_table_name: &str,
163  ) -> Result<Report, Error> {
164    let migrations = self.get_unapplied_migrations(
165      migrations,
166      abort_divergent,
167      abort_missing,
168      migration_table_name,
169    )?;
170
171    if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
172      migrate(self, migrations, target, migration_table_name, true)
173    } else {
174      migrate(self, migrations, target, migration_table_name, false)
175    }
176  }
177}