refinery_core/traits/sync.rs

1use crate::error::WrapMigrationError;
2use crate::traits::{
3    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
4    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
5};
6use crate::{Error, Migration, Report, Target};
7
/// A handle that can run a list of SQL statements against a database.
pub trait Transaction {
    /// Error produced by the underlying driver when a statement fails.
    type Error: 'static + Send + Sync + std::error::Error;

    /// Runs every statement in `queries` and returns the count reported by the
    /// implementation.
    fn execute(&mut self, queries: &[&str]) -> Result<usize, Self::Error>;
}

/// A [`Transaction`] that can also run a query and decode its result as a `T`.
pub trait Query<T>: Transaction {
    /// Runs `query` and converts its result into a `T`.
    fn query(&mut self, query: &str) -> Result<T, Self::Error>;
}
17
18pub fn migrate<T: Transaction>(
19    transaction: &mut T,
20    migrations: Vec<Migration>,
21    target: Target,
22    migration_table_name: &str,
23    batched: bool,
24) -> Result<Report, Error> {
25    let mut migration_batch = Vec::new();
26    let mut applied_migrations = Vec::new();
27
28    for mut migration in migrations.into_iter() {
29        if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
30            if input_target < migration.version() {
31                log::info!(
32                    "stopping at migration: {}, due to user option",
33                    input_target
34                );
35                break;
36            }
37        }
38
39        log::info!("applying migration: {}", migration);
40        migration.set_applied();
41        let insert_migration = insert_migration_query(&migration, migration_table_name);
42        let migration_sql = migration.sql().expect("sql must be Some!").to_string();
43
44        // If Target is Fake, we only update schema migrations table
45        if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
46            applied_migrations.push(migration);
47            migration_batch.push(migration_sql);
48        }
49        migration_batch.push(insert_migration);
50    }
51
52    match (target, batched) {
53        (Target::Fake | Target::FakeVersion(_), _) => {
54            log::info!("not going to apply any migration as fake flag is enabled");
55        }
56        (Target::Latest | Target::Version(_), true) => {
57            log::info!(
58                "going to apply batch migrations in single transaction: {:#?}",
59                applied_migrations.iter().map(ToString::to_string)
60            );
61        }
62        (Target::Latest | Target::Version(_), false) => {
63            log::info!(
64                "preparing to apply {} migrations: {:#?}",
65                applied_migrations.len(),
66                applied_migrations.iter().map(ToString::to_string)
67            );
68        }
69    };
70
71    let refs: Vec<&str> = migration_batch.iter().map(AsRef::as_ref).collect();
72
73    if batched {
74        transaction
75            .execute(refs.as_ref())
76            .migration_err("error applying migrations", None)?;
77    } else {
78        for (i, update) in refs.iter().enumerate() {
79            transaction
80                .execute(&[update])
81                .migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?;
82        }
83    }
84
85    Ok(Report::new(applied_migrations))
86}
87
88pub trait Migrate: Query<Vec<Migration>>
89where
90    Self: Sized,
91{
92    // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table
93    fn assert_migrations_table_query(migration_table_name: &str) -> String {
94        ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
95    }
96
97    fn get_last_applied_migration_query(migration_table_name: &str) -> String {
98        GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
99    }
100
101    fn get_applied_migrations_query(migration_table_name: &str) -> String {
102        GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
103    }
104
105    fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
106        // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table,
107        // thou on this case it's just to be consistent with the async trait `AsyncMigrate`
108        self.execute(&[Self::assert_migrations_table_query(migration_table_name).as_str()])
109            .migration_err("error asserting migrations table", None)
110    }
111
112    fn get_last_applied_migration(
113        &mut self,
114        migration_table_name: &str,
115    ) -> Result<Option<Migration>, Error> {
116        let mut migrations = self
117            .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
118            .migration_err("error getting last applied migration", None)?;
119
120        Ok(migrations.pop())
121    }
122
123    fn get_applied_migrations(
124        &mut self,
125        migration_table_name: &str,
126    ) -> Result<Vec<Migration>, Error> {
127        let migrations = self
128            .query(Self::get_applied_migrations_query(migration_table_name).as_str())
129            .migration_err("error getting applied migrations", None)?;
130
131        Ok(migrations)
132    }
133
134    fn get_unapplied_migrations(
135        &mut self,
136        migrations: &[Migration],
137        abort_divergent: bool,
138        abort_missing: bool,
139        migration_table_name: &str,
140    ) -> Result<Vec<Migration>, Error> {
141        self.assert_migrations_table(migration_table_name)?;
142
143        let applied_migrations = self.get_applied_migrations(migration_table_name)?;
144
145        let migrations = verify_migrations(
146            applied_migrations,
147            migrations.to_vec(),
148            abort_divergent,
149            abort_missing,
150        )?;
151
152        if migrations.is_empty() {
153            log::info!("no migrations to apply");
154        }
155
156        Ok(migrations)
157    }
158
159    fn migrate(
160        &mut self,
161        migrations: &[Migration],
162        abort_divergent: bool,
163        abort_missing: bool,
164        grouped: bool,
165        target: Target,
166        migration_table_name: &str,
167    ) -> Result<Report, Error> {
168        let migrations = self.get_unapplied_migrations(
169            migrations,
170            abort_divergent,
171            abort_missing,
172            migration_table_name,
173        )?;
174
175        if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
176            migrate(self, migrations, target, migration_table_name, true)
177        } else {
178            migrate(self, migrations, target, migration_table_name, false)
179        }
180    }
181}