refinery_core/traits/sync.rs

1use std::ops::Deref;
2
3use crate::error::WrapMigrationError;
4use crate::traits::{
5    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
6    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
7};
8use crate::{Error, Migration, Report, Target};
9
/// A database handle that can run a sequence of SQL statements in one call.
///
/// The migration code in this module may hand several statements to a single
/// `execute` call (see the grouped path of `migrate`). Whether they share one
/// database transaction is up to the implementation — the trait name suggests so;
/// confirm per driver.
pub trait Transaction {
    /// Error type produced by the underlying driver.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Executes every statement yielded by `queries`.
    ///
    /// Returns a `usize` reported by the implementation (its meaning, e.g. a
    /// statement or row count, is driver-specific — NOTE(review): confirm).
    fn execute<'a, Q>(&mut self, queries: Q) -> Result<usize, Self::Error>
    where
        Q: Iterator<Item = &'a str>;
}
18
19pub trait Query<T>: Transaction {
20    fn query(&mut self, query: &str) -> Result<T, Self::Error>;
21}
22
23pub fn migrate<T: Transaction>(
24    transaction: &mut T,
25    migrations: Vec<Migration>,
26    target: Target,
27    migration_table_name: &str,
28    grouped: bool,
29) -> Result<Report, Error> {
30    let mut migration_batch = Vec::new();
31    let mut applied_migrations = Vec::new();
32
33    for mut migration in migrations.into_iter() {
34        if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
35            if input_target < migration.version() {
36                log::info!(
37                    "stopping at migration: {}, due to user option",
38                    input_target
39                );
40                break;
41            }
42        }
43
44        log::info!("applying migration: {}", migration);
45        migration.set_applied();
46        let insert_migration = insert_migration_query(&migration, migration_table_name);
47        let migration_sql = migration.sql().expect("sql must be Some!").to_string();
48
49        // If Target is Fake, we only update schema migrations table
50        if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
51            applied_migrations.push(migration);
52            migration_batch.push(migration_sql);
53        }
54        migration_batch.push(insert_migration);
55    }
56
57    match (target, grouped) {
58        (Target::Fake | Target::FakeVersion(_), _) => {
59            log::info!("not going to apply any migration as fake flag is enabled");
60        }
61        (Target::Latest | Target::Version(_), true) => {
62            log::info!(
63                "going to apply batch migrations in single transaction: {:#?}",
64                applied_migrations.iter().map(ToString::to_string)
65            );
66        }
67        (Target::Latest | Target::Version(_), false) => {
68            log::info!(
69                "preparing to apply {} migrations: {:#?}",
70                applied_migrations.len(),
71                applied_migrations.iter().map(ToString::to_string)
72            );
73        }
74    };
75
76    if grouped {
77        transaction
78            .execute(migration_batch.iter().map(Deref::deref))
79            .migration_err("error applying migrations", None)?;
80    } else {
81        for (i, update) in migration_batch.into_iter().enumerate() {
82            transaction
83                .execute([update.as_str()].into_iter())
84                .migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?;
85        }
86    }
87
88    Ok(Report::new(applied_migrations))
89}
90
91pub trait Migrate: Query<Vec<Migration>>
92where
93    Self: Sized,
94{
95    // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table
96    fn assert_migrations_table_query(migration_table_name: &str) -> String {
97        ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
98    }
99
100    fn get_last_applied_migration_query(migration_table_name: &str) -> String {
101        GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
102    }
103
104    fn get_applied_migrations_query(migration_table_name: &str) -> String {
105        GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
106    }
107
108    fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
109        // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table,
110        // thou on this case it's just to be consistent with the async trait `AsyncMigrate`
111        self.execute(
112            [Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter(),
113        )
114        .migration_err("error asserting migrations table", None)
115    }
116
117    fn get_last_applied_migration(
118        &mut self,
119        migration_table_name: &str,
120    ) -> Result<Option<Migration>, Error> {
121        let mut migrations = self
122            .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
123            .migration_err("error getting last applied migration", None)?;
124
125        Ok(migrations.pop())
126    }
127
128    fn get_applied_migrations(
129        &mut self,
130        migration_table_name: &str,
131    ) -> Result<Vec<Migration>, Error> {
132        let migrations = self
133            .query(Self::get_applied_migrations_query(migration_table_name).as_str())
134            .migration_err("error getting applied migrations", None)?;
135
136        Ok(migrations)
137    }
138
139    fn get_unapplied_migrations(
140        &mut self,
141        migrations: &[Migration],
142        abort_divergent: bool,
143        abort_missing: bool,
144        migration_table_name: &str,
145    ) -> Result<Vec<Migration>, Error> {
146        self.assert_migrations_table(migration_table_name)?;
147
148        let applied_migrations = self.get_applied_migrations(migration_table_name)?;
149
150        let migrations = verify_migrations(
151            applied_migrations,
152            migrations.to_vec(),
153            abort_divergent,
154            abort_missing,
155        )?;
156
157        if migrations.is_empty() {
158            log::info!("no migrations to apply");
159        }
160
161        Ok(migrations)
162    }
163
164    fn migrate(
165        &mut self,
166        migrations: &[Migration],
167        abort_divergent: bool,
168        abort_missing: bool,
169        grouped: bool,
170        target: Target,
171        migration_table_name: &str,
172    ) -> Result<Report, Error> {
173        let migrations = self.get_unapplied_migrations(
174            migrations,
175            abort_divergent,
176            abort_missing,
177            migration_table_name,
178        )?;
179
180        if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
181            migrate(self, migrations, target, migration_table_name, true)
182        } else {
183            migrate(self, migrations, target, migration_table_name, false)
184        }
185    }
186}