// refinery_core/traits/sync.rs

use crate::error::WrapMigrationError;
use crate::traits::{
    insert_migration_query, verify_migrations, GET_APPLIED_MIGRATIONS_QUERY,
    GET_LAST_APPLIED_MIGRATION_QUERY,
};
use crate::{Error, Migration, Report, Target};

/// A database handle that can execute a sequence of raw SQL statements.
pub trait Transaction {
    /// Driver-specific error returned when executing statements fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Executes every statement yielded by `queries`.
    ///
    /// Returns a `usize` on success. NOTE(review): presumably a row or statement count;
    /// the exact meaning is defined by each implementation — confirm there.
    fn execute<'a, I>(&mut self, queries: I) -> Result<usize, Self::Error>
    where
        I: Iterator<Item = &'a str>;
}

17pub trait Query<T>: Transaction {
18    fn query(&mut self, query: &str) -> Result<T, Self::Error>;
19}
20
21pub fn migrate<T: Transaction>(
22    transaction: &mut T,
23    migrations: Vec<Migration>,
24    target: Target,
25    migration_table_name: &str,
26    batched: bool,
27) -> Result<Report, Error> {
28    let mut migration_batch = Vec::new();
29    let mut applied_migrations = Vec::new();
30
31    for mut migration in migrations.into_iter() {
32        if let Target::Version(input_target) | Target::FakeVersion(input_target) = target {
33            if input_target < migration.version() {
34                log::info!(
35                    "stopping at migration: {}, due to user option",
36                    input_target
37                );
38                break;
39            }
40        }
41
42        migration.set_applied();
43        let insert_migration = insert_migration_query(&migration, migration_table_name);
44        let migration_sql = migration.sql().expect("sql must be Some!").to_string();
45
46        // If Target is Fake, we only update schema migrations table
47        if !matches!(target, Target::Fake | Target::FakeVersion(_)) {
48            applied_migrations.push(migration);
49            migration_batch.push(migration_sql);
50        }
51        migration_batch.push(insert_migration);
52    }
53
54    match (target, batched) {
55        (Target::Fake | Target::FakeVersion(_), _) => {
56            log::info!("not going to apply any migration as fake flag is enabled.");
57        }
58        (Target::Latest | Target::Version(_), true) => {
59            log::info!(
60                "going to batch apply {} migrations in single transaction.",
61                applied_migrations.len()
62            );
63        }
64        (Target::Latest | Target::Version(_), false) => {
65            log::info!(
66                "going to apply {} migrations in multiple transactions.",
67                applied_migrations.len(),
68            );
69        }
70    };
71
72    let refs = migration_batch.iter().map(AsRef::as_ref);
73
74    if batched {
75        let migrations_display = applied_migrations
76            .iter()
77            .map(ToString::to_string)
78            .collect::<Vec<String>>()
79            .join("\n");
80        log::info!("going to apply batch migrations in single transaction:\n{migrations_display}");
81        transaction
82            .execute(refs)
83            .migration_err("error applying migrations", None)?;
84    } else {
85        for (i, update) in refs.enumerate() {
86            // first iteration is pair so we know the following even in the iteration index
87            // marks the previous (pair) migration as completed.
88            let applying_migration = i % 2 == 0;
89            let current_migration = &applied_migrations[i / 2];
90            if applying_migration {
91                log::info!("applying migration: {current_migration} ...");
92            } else {
93                // Writing the migration state to the db.
94                log::debug!("applied migration:  {current_migration} writing state to db.");
95            }
96            transaction
97                .execute([update].into_iter())
98                .migration_err("error applying update", Some(&applied_migrations[0..i / 2]))?;
99        }
100    }
101
102    Ok(Report::new(applied_migrations))
103}
104
105pub trait Migrate: Query<Vec<Migration>>
106where
107    Self: Sized,
108{
109    // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table
110    fn assert_migrations_table_query(migration_table_name: &str) -> String {
111        super::assert_migrations_table_query(migration_table_name)
112    }
113
114    fn get_last_applied_migration_query(migration_table_name: &str) -> String {
115        GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
116    }
117
118    fn get_applied_migrations_query(migration_table_name: &str) -> String {
119        GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
120    }
121
122    fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
123        // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table,
124        // though on this case it's just to be consistent with the async trait `AsyncMigrate`
125        self.execute(
126            [Self::assert_migrations_table_query(migration_table_name).as_ref()].into_iter(),
127        )
128        .migration_err("error asserting migrations table", None)
129    }
130
131    fn get_last_applied_migration(
132        &mut self,
133        migration_table_name: &str,
134    ) -> Result<Option<Migration>, Error> {
135        let mut migrations = self
136            .query(Self::get_last_applied_migration_query(migration_table_name).as_str())
137            .migration_err("error getting last applied migration", None)?;
138
139        Ok(migrations.pop())
140    }
141
142    fn get_applied_migrations(
143        &mut self,
144        migration_table_name: &str,
145    ) -> Result<Vec<Migration>, Error> {
146        let migrations = self
147            .query(Self::get_applied_migrations_query(migration_table_name).as_str())
148            .migration_err("error getting applied migrations", None)?;
149
150        Ok(migrations)
151    }
152
153    fn get_unapplied_migrations(
154        &mut self,
155        migrations: &[Migration],
156        abort_divergent: bool,
157        abort_missing: bool,
158        migration_table_name: &str,
159    ) -> Result<Vec<Migration>, Error> {
160        self.assert_migrations_table(migration_table_name)?;
161
162        let applied_migrations = self.get_applied_migrations(migration_table_name)?;
163
164        let migrations = verify_migrations(
165            applied_migrations,
166            migrations.to_vec(),
167            abort_divergent,
168            abort_missing,
169        )?;
170
171        if migrations.is_empty() {
172            log::info!("no migrations to apply");
173        }
174
175        Ok(migrations)
176    }
177
178    fn migrate(
179        &mut self,
180        migrations: &[Migration],
181        abort_divergent: bool,
182        abort_missing: bool,
183        grouped: bool,
184        target: Target,
185        migration_table_name: &str,
186    ) -> Result<Report, Error> {
187        let migrations = self.get_unapplied_migrations(
188            migrations,
189            abort_divergent,
190            abort_missing,
191            migration_table_name,
192        )?;
193
194        if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
195            migrate(self, migrations, target, migration_table_name, true)
196        } else {
197            migrate(self, migrations, target, migration_table_name, false)
198        }
199    }
200}