1use std::collections::HashMap;
2use std::fmt;
3
4use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, Migrator};
5
6use crate::DbPool;
7
/// Runledger's migration set, resolved by the `sqlx::migrate!` macro from `./migrations`.
///
/// `migrate` applies the up-migrations in this set, and `ensure_schema_compatible` checks
/// `_sqlx_migrations` against them. The `locking` flag of this migrator decides whether
/// `migrate` takes the migration lock around its work.
pub static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
9
/// Reasons the Runledger schema compatibility check can reject the current database.
#[derive(Debug)]
pub enum SchemaCompatibilityError {
    /// Built from any `sqlx::Error`, e.g. a failed connection acquire or a failed query.
    Query(sqlx::Error),
    /// The `_sqlx_migrations` table does not exist, so there is no history to compare.
    MissingMigrationHistory {
        /// Version of the first Runledger up-migration; included in the error message to
        /// tell operators where the expected history starts.
        required_first_migration_version: i64,
    },
    /// Recorded migration history is not compatible with the migrations in `MIGRATOR`.
    Incompatible(MigrateError),
}
18
19impl fmt::Display for SchemaCompatibilityError {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 match self {
22 Self::Query(error) => write!(
23 f,
24 "Runledger schema compatibility check could not query PostgreSQL state: {error}"
25 ),
26 Self::MissingMigrationHistory {
27 required_first_migration_version,
28 } => write!(
29 f,
30 "Runledger schema compatibility check requires the _sqlx_migrations table; apply or record Runledger migrations first (expected migration history starting at version {required_first_migration_version})"
31 ),
32 Self::Incompatible(error) => write!(f, "{error}"),
33 }
34 }
35}
36
37impl std::error::Error for SchemaCompatibilityError {
38 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
39 match self {
40 Self::Query(error) => Some(error),
41 Self::MissingMigrationHistory { .. } => None,
42 Self::Incompatible(error) => Some(error),
43 }
44 }
45}
46
47impl From<MigrateError> for SchemaCompatibilityError {
48 fn from(error: MigrateError) -> Self {
49 Self::Incompatible(error)
50 }
51}
52
53impl From<sqlx::Error> for SchemaCompatibilityError {
54 fn from(error: sqlx::Error) -> Self {
55 Self::Query(error)
56 }
57}
58
59pub async fn migrate(pool: &DbPool) -> Result<(), MigrateError> {
61 let mut conn = pool.acquire().await?;
62
63 if MIGRATOR.locking {
64 (*conn).lock().await?;
65 }
66
67 let result = run_migrations_with_filtered_history(&mut conn).await;
68 let unlock_result = if MIGRATOR.locking {
69 (*conn).unlock().await
70 } else {
71 Ok(())
72 };
73
74 match (result, unlock_result) {
75 (Err(error), _) => Err(error),
76 (Ok(()), Err(error)) => Err(error),
77 (Ok(()), Ok(())) => Ok(()),
78 }
79}
80
81pub async fn ensure_schema_compatible(pool: &DbPool) -> Result<(), SchemaCompatibilityError> {
91 let mut conn = pool.acquire().await?;
92
93 if !has_migrations_table(&mut conn).await? {
94 return Err(SchemaCompatibilityError::MissingMigrationHistory {
95 required_first_migration_version: first_up_migration_version(),
96 });
97 }
98
99 let expected_migrations = expected_runledger_migrations();
100 let history = list_migration_history(&mut conn).await?;
101
102 if let Some(version) = first_conflicting_runledger_version(&history, &expected_migrations) {
103 return Err(SchemaCompatibilityError::Incompatible(
104 MigrateError::VersionMismatch(version),
105 ));
106 }
107
108 if let Some(version) = first_dirty_runledger_version(&history, &expected_migrations) {
109 return Err(SchemaCompatibilityError::Incompatible(MigrateError::Dirty(
110 version,
111 )));
112 }
113
114 if has_runledger_migration_history_table(&mut conn).await? {
115 let recorded_versions = list_recorded_runledger_migrations(&mut conn).await?;
116 if let Some(version) =
117 first_missing_runledger_version(&recorded_versions, &expected_migrations)
118 {
119 return Err(SchemaCompatibilityError::Incompatible(
120 MigrateError::VersionMissing(version),
121 ));
122 }
123 }
124
125 let applied = applied_runledger_migrations(&history, &expected_migrations);
126 let applied_by_version: HashMap<_, _> = applied
127 .iter()
128 .map(|applied_migration| (applied_migration.version, applied_migration))
129 .collect();
130 let latest_applied_version = applied.iter().map(|migration| migration.version).max();
131
132 for migration in MIGRATOR
133 .iter()
134 .filter(|migration| migration.migration_type.is_up_migration())
135 {
136 match applied_by_version.get(&migration.version) {
137 Some(applied_migration) => {
138 validate_checksum(migration.version, applied_migration, migration)
139 .map_err(SchemaCompatibilityError::from)?
140 }
141 None => {
142 return Err(SchemaCompatibilityError::Incompatible(
143 MigrateError::VersionTooNew(
144 migration.version,
145 latest_applied_version.unwrap_or_default(),
146 ),
147 ));
148 }
149 }
150 }
151
152 Ok(())
153}
154
155async fn has_migrations_table(
156 conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
157) -> Result<bool, sqlx::Error> {
158 sqlx::query_scalar::<_, bool>("SELECT to_regclass('_sqlx_migrations') IS NOT NULL")
159 .fetch_one(&mut **conn)
160 .await
161}
162
163async fn has_runledger_migration_history_table(
164 conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
165) -> Result<bool, sqlx::Error> {
166 sqlx::query_scalar::<_, bool>("SELECT to_regclass('runledger_migration_history') IS NOT NULL")
167 .fetch_one(&mut **conn)
168 .await
169}
170
171async fn list_migration_history(
172 conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
173) -> Result<Vec<MigrationHistoryRow>, sqlx::Error> {
174 sqlx::query_as::<_, MigrationHistoryRow>(
175 "SELECT version, checksum, success
176 FROM _sqlx_migrations
177 ORDER BY version",
178 )
179 .fetch_all(&mut **conn)
180 .await
181}
182
183async fn list_recorded_runledger_migrations(
184 conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
185) -> Result<Vec<i64>, sqlx::Error> {
186 sqlx::query_scalar::<_, i64>(
187 "SELECT version
188 FROM runledger_migration_history
189 ORDER BY version",
190 )
191 .fetch_all(&mut **conn)
192 .await
193}
194
195fn first_up_migration_version() -> i64 {
196 MIGRATOR
197 .iter()
198 .find(|migration| migration.migration_type.is_up_migration())
199 .map(|migration| migration.version)
200 .unwrap_or_default()
201}
202
203fn expected_runledger_migrations() -> HashMap<i64, &'static sqlx::migrate::Migration> {
204 MIGRATOR
205 .iter()
206 .filter(|migration| migration.migration_type.is_up_migration())
207 .map(|migration| (migration.version, migration))
208 .collect()
209}
210
211fn first_conflicting_runledger_version(
212 history: &[MigrationHistoryRow],
213 expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
214) -> Option<i64> {
215 history.iter().find_map(|row| {
216 expected_migrations
217 .get(&row.version)
218 .filter(|migration| row.checksum.as_slice() != migration.checksum.as_ref())
219 .map(|_| row.version)
220 })
221}
222
223fn first_dirty_runledger_version(
224 history: &[MigrationHistoryRow],
225 expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
226) -> Option<i64> {
227 history.iter().find_map(|row| {
228 (!row.success)
229 .then(|| {
230 expected_migrations
231 .get(&row.version)
232 .filter(|migration| row.checksum.as_slice() == migration.checksum.as_ref())
233 .map(|_| row.version)
234 })
235 .flatten()
236 })
237}
238
239fn first_missing_runledger_version(
240 recorded_versions: &[i64],
241 expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
242) -> Option<i64> {
243 recorded_versions
244 .iter()
245 .copied()
246 .find(|version| !expected_migrations.contains_key(version))
247}
248
249fn applied_runledger_migrations(
250 history: &[MigrationHistoryRow],
251 expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
252) -> Vec<AppliedMigration> {
253 history
254 .iter()
255 .filter(|row| row.success)
256 .filter(|row| {
257 expected_migrations
258 .get(&row.version)
259 .is_some_and(|migration| row.checksum.as_slice() == migration.checksum.as_ref())
260 })
261 .map(|row| AppliedMigration {
262 version: row.version,
263 checksum: row.checksum.clone().into(),
264 })
265 .collect()
266}
267
268async fn run_migrations_with_filtered_history(
269 conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
270) -> Result<(), MigrateError> {
271 (**conn).ensure_migrations_table().await?;
272
273 let expected_migrations = expected_runledger_migrations();
274 let history = list_migration_history(conn).await?;
275
276 if let Some(version) = first_conflicting_runledger_version(&history, &expected_migrations) {
277 return Err(MigrateError::VersionMismatch(version));
278 }
279
280 if let Some(version) = first_dirty_runledger_version(&history, &expected_migrations) {
281 return Err(MigrateError::Dirty(version));
282 }
283
284 if has_runledger_migration_history_table(conn).await? {
285 let recorded_versions = list_recorded_runledger_migrations(conn).await?;
286 if let Some(version) =
287 first_missing_runledger_version(&recorded_versions, &expected_migrations)
288 {
289 return Err(MigrateError::VersionMissing(version));
290 }
291 }
292
293 let applied = applied_runledger_migrations(&history, &expected_migrations);
294 let applied_by_version: HashMap<_, _> = applied
295 .into_iter()
296 .map(|migration| (migration.version, migration))
297 .collect();
298
299 for migration in MIGRATOR
300 .iter()
301 .filter(|migration| migration.migration_type.is_up_migration())
302 {
303 match applied_by_version.get(&migration.version) {
304 Some(applied_migration) => {
305 validate_checksum(migration.version, applied_migration, migration)?
306 }
307 None => {
308 (**conn).apply(migration).await?;
309 }
310 }
311 }
312
313 Ok(())
314}
315
/// One row of `_sqlx_migrations`, as selected by `list_migration_history`.
///
/// Field names must match the selected column names, because `sqlx::FromRow` maps columns
/// by name.
#[derive(sqlx::FromRow)]
struct MigrationHistoryRow {
    /// Value of the `version` column.
    version: i64,
    /// Stored `checksum` bytes; compared against the embedded migration's checksum.
    checksum: Vec<u8>,
    /// `success` column; `false` is treated as a dirty (partially applied) migration.
    success: bool,
}
322
323fn validate_checksum(
324 version: i64,
325 applied_migration: &AppliedMigration,
326 expected_migration: &sqlx::migrate::Migration,
327) -> Result<(), MigrateError> {
328 if applied_migration.checksum != expected_migration.checksum {
329 return Err(MigrateError::VersionMismatch(version));
330 }
331
332 Ok(())
333}