Skip to main content

runledger_postgres/
migrations.rs

1use std::collections::HashMap;
2use std::fmt;
3
4use sqlx::migrate::{AppliedMigration, Migrate, MigrateError, Migrator};
5
6use crate::DbPool;
7
/// Migrator built by `sqlx::migrate!` from the `./migrations` directory.
///
/// The rest of this module treats the up-migrations it contains as the
/// expected Runledger schema history.
pub static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
9
/// Reasons [`ensure_schema_compatible`] can reject a database.
#[derive(Debug)]
pub enum SchemaCompatibilityError {
    /// A `sqlx::Error` surfaced while acquiring a connection or querying
    /// migration state.
    Query(sqlx::Error),
    /// The `_sqlx_migrations` table could not be found, so there is no
    /// history to compare against.
    MissingMigrationHistory {
        /// Version of the first bundled up-migration (0 when none is bundled).
        required_first_migration_version: i64,
    },
    /// The recorded migration history disagrees with the bundled migrations.
    Incompatible(MigrateError),
}
18
19impl fmt::Display for SchemaCompatibilityError {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match self {
22            Self::Query(error) => write!(
23                f,
24                "Runledger schema compatibility check could not query PostgreSQL state: {error}"
25            ),
26            Self::MissingMigrationHistory {
27                required_first_migration_version,
28            } => write!(
29                f,
30                "Runledger schema compatibility check requires the _sqlx_migrations table; apply or record Runledger migrations first (expected migration history starting at version {required_first_migration_version})"
31            ),
32            Self::Incompatible(error) => write!(f, "{error}"),
33        }
34    }
35}
36
37impl std::error::Error for SchemaCompatibilityError {
38    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
39        match self {
40            Self::Query(error) => Some(error),
41            Self::MissingMigrationHistory { .. } => None,
42            Self::Incompatible(error) => Some(error),
43        }
44    }
45}
46
47impl From<MigrateError> for SchemaCompatibilityError {
48    fn from(error: MigrateError) -> Self {
49        Self::Incompatible(error)
50    }
51}
52
53impl From<sqlx::Error> for SchemaCompatibilityError {
54    fn from(error: sqlx::Error) -> Self {
55        Self::Query(error)
56    }
57}
58
59/// Apply the bundled Runledger schema migrations to a PostgreSQL pool.
60pub async fn migrate(pool: &DbPool) -> Result<(), MigrateError> {
61    let mut conn = pool.acquire().await?;
62
63    if MIGRATOR.locking {
64        (*conn).lock().await?;
65    }
66
67    let result = run_migrations_with_filtered_history(&mut conn).await;
68    let unlock_result = if MIGRATOR.locking {
69        (*conn).unlock().await
70    } else {
71        Ok(())
72    };
73
74    match (result, unlock_result) {
75        (Err(error), _) => Err(error),
76        (Ok(()), Err(error)) => Err(error),
77        (Ok(()), Ok(())) => Ok(()),
78    }
79}
80
81/// Validate that the target database's SQLx migration history matches the
82/// bundled Runledger migrations.
83///
84/// Unlike [`migrate`], this does not apply pending migrations. It is intended
85/// for deployments that manage DDL outside the application process but still
86/// want a startup guardrail. This check is read-only, but it relies on the
87/// `_sqlx_migrations` history table being present and up to date. When present,
88/// it also uses Runledger's own `runledger_migration_history` table to detect
89/// migrations applied by newer Runledger releases.
90pub async fn ensure_schema_compatible(pool: &DbPool) -> Result<(), SchemaCompatibilityError> {
91    let mut conn = pool.acquire().await?;
92
93    if !has_migrations_table(&mut conn).await? {
94        return Err(SchemaCompatibilityError::MissingMigrationHistory {
95            required_first_migration_version: first_up_migration_version(),
96        });
97    }
98
99    let expected_migrations = expected_runledger_migrations();
100    let history = list_migration_history(&mut conn).await?;
101
102    if let Some(version) = first_conflicting_runledger_version(&history, &expected_migrations) {
103        return Err(SchemaCompatibilityError::Incompatible(
104            MigrateError::VersionMismatch(version),
105        ));
106    }
107
108    if let Some(version) = first_dirty_runledger_version(&history, &expected_migrations) {
109        return Err(SchemaCompatibilityError::Incompatible(MigrateError::Dirty(
110            version,
111        )));
112    }
113
114    if has_runledger_migration_history_table(&mut conn).await? {
115        let recorded_versions = list_recorded_runledger_migrations(&mut conn).await?;
116        if let Some(version) =
117            first_missing_runledger_version(&recorded_versions, &expected_migrations)
118        {
119            return Err(SchemaCompatibilityError::Incompatible(
120                MigrateError::VersionMissing(version),
121            ));
122        }
123    }
124
125    let applied = applied_runledger_migrations(&history, &expected_migrations);
126    let applied_by_version: HashMap<_, _> = applied
127        .iter()
128        .map(|applied_migration| (applied_migration.version, applied_migration))
129        .collect();
130    let latest_applied_version = applied.iter().map(|migration| migration.version).max();
131
132    for migration in MIGRATOR
133        .iter()
134        .filter(|migration| migration.migration_type.is_up_migration())
135    {
136        match applied_by_version.get(&migration.version) {
137            Some(applied_migration) => {
138                validate_checksum(migration.version, applied_migration, migration)
139                    .map_err(SchemaCompatibilityError::from)?
140            }
141            None => {
142                return Err(SchemaCompatibilityError::Incompatible(
143                    MigrateError::VersionTooNew(
144                        migration.version,
145                        latest_applied_version.unwrap_or_default(),
146                    ),
147                ));
148            }
149        }
150    }
151
152    Ok(())
153}
154
155async fn has_migrations_table(
156    conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
157) -> Result<bool, sqlx::Error> {
158    sqlx::query_scalar::<_, bool>("SELECT to_regclass('_sqlx_migrations') IS NOT NULL")
159        .fetch_one(&mut **conn)
160        .await
161}
162
163async fn has_runledger_migration_history_table(
164    conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
165) -> Result<bool, sqlx::Error> {
166    sqlx::query_scalar::<_, bool>("SELECT to_regclass('runledger_migration_history') IS NOT NULL")
167        .fetch_one(&mut **conn)
168        .await
169}
170
171async fn list_migration_history(
172    conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
173) -> Result<Vec<MigrationHistoryRow>, sqlx::Error> {
174    sqlx::query_as::<_, MigrationHistoryRow>(
175        "SELECT version, checksum, success
176         FROM _sqlx_migrations
177         ORDER BY version",
178    )
179    .fetch_all(&mut **conn)
180    .await
181}
182
183async fn list_recorded_runledger_migrations(
184    conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
185) -> Result<Vec<i64>, sqlx::Error> {
186    sqlx::query_scalar::<_, i64>(
187        "SELECT version
188         FROM runledger_migration_history
189         ORDER BY version",
190    )
191    .fetch_all(&mut **conn)
192    .await
193}
194
195fn first_up_migration_version() -> i64 {
196    MIGRATOR
197        .iter()
198        .find(|migration| migration.migration_type.is_up_migration())
199        .map(|migration| migration.version)
200        .unwrap_or_default()
201}
202
203fn expected_runledger_migrations() -> HashMap<i64, &'static sqlx::migrate::Migration> {
204    MIGRATOR
205        .iter()
206        .filter(|migration| migration.migration_type.is_up_migration())
207        .map(|migration| (migration.version, migration))
208        .collect()
209}
210
211fn first_conflicting_runledger_version(
212    history: &[MigrationHistoryRow],
213    expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
214) -> Option<i64> {
215    history.iter().find_map(|row| {
216        expected_migrations
217            .get(&row.version)
218            .filter(|migration| row.checksum.as_slice() != migration.checksum.as_ref())
219            .map(|_| row.version)
220    })
221}
222
223fn first_dirty_runledger_version(
224    history: &[MigrationHistoryRow],
225    expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
226) -> Option<i64> {
227    history.iter().find_map(|row| {
228        (!row.success)
229            .then(|| {
230                expected_migrations
231                    .get(&row.version)
232                    .filter(|migration| row.checksum.as_slice() == migration.checksum.as_ref())
233                    .map(|_| row.version)
234            })
235            .flatten()
236    })
237}
238
239fn first_missing_runledger_version(
240    recorded_versions: &[i64],
241    expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
242) -> Option<i64> {
243    recorded_versions
244        .iter()
245        .copied()
246        .find(|version| !expected_migrations.contains_key(version))
247}
248
249fn applied_runledger_migrations(
250    history: &[MigrationHistoryRow],
251    expected_migrations: &HashMap<i64, &'static sqlx::migrate::Migration>,
252) -> Vec<AppliedMigration> {
253    history
254        .iter()
255        .filter(|row| row.success)
256        .filter(|row| {
257            expected_migrations
258                .get(&row.version)
259                .is_some_and(|migration| row.checksum.as_slice() == migration.checksum.as_ref())
260        })
261        .map(|row| AppliedMigration {
262            version: row.version,
263            checksum: row.checksum.clone().into(),
264        })
265        .collect()
266}
267
268async fn run_migrations_with_filtered_history(
269    conn: &mut sqlx::pool::PoolConnection<sqlx::Postgres>,
270) -> Result<(), MigrateError> {
271    (**conn).ensure_migrations_table().await?;
272
273    let expected_migrations = expected_runledger_migrations();
274    let history = list_migration_history(conn).await?;
275
276    if let Some(version) = first_conflicting_runledger_version(&history, &expected_migrations) {
277        return Err(MigrateError::VersionMismatch(version));
278    }
279
280    if let Some(version) = first_dirty_runledger_version(&history, &expected_migrations) {
281        return Err(MigrateError::Dirty(version));
282    }
283
284    if has_runledger_migration_history_table(conn).await? {
285        let recorded_versions = list_recorded_runledger_migrations(conn).await?;
286        if let Some(version) =
287            first_missing_runledger_version(&recorded_versions, &expected_migrations)
288        {
289            return Err(MigrateError::VersionMissing(version));
290        }
291    }
292
293    let applied = applied_runledger_migrations(&history, &expected_migrations);
294    let applied_by_version: HashMap<_, _> = applied
295        .into_iter()
296        .map(|migration| (migration.version, migration))
297        .collect();
298
299    for migration in MIGRATOR
300        .iter()
301        .filter(|migration| migration.migration_type.is_up_migration())
302    {
303        match applied_by_version.get(&migration.version) {
304            Some(applied_migration) => {
305                validate_checksum(migration.version, applied_migration, migration)?
306            }
307            None => {
308                (**conn).apply(migration).await?;
309            }
310        }
311    }
312
313    Ok(())
314}
315
/// One row of `_sqlx_migrations`, limited to the columns this module reads.
#[derive(sqlx::FromRow)]
struct MigrationHistoryRow {
    /// Migration version number.
    version: i64,
    /// Checksum stored for the migration; compared against the bundled
    /// migration's checksum.
    checksum: Vec<u8>,
    /// Value of the `success` column; rows where it is false are never counted
    /// as applied.
    success: bool,
}
322
323fn validate_checksum(
324    version: i64,
325    applied_migration: &AppliedMigration,
326    expected_migration: &sqlx::migrate::Migration,
327) -> Result<(), MigrateError> {
328    if applied_migration.checksum != expected_migration.checksum {
329        return Err(MigrateError::VersionMismatch(version));
330    }
331
332    Ok(())
333}