Skip to main content

klauthed_data/
migrate.rs

1//! Embedded, versioned schema migrations over a relational pool
2//! (`feature = "sql"`).
3//!
4//! Define forward [`Migration`]s and run them with a [`Migrator`]; applied
5//! versions are tracked in a `_klauthed_migrations` table so each runs exactly
6//! once and re-running is a no-op. Works over the driver-agnostic
7//! [`sqlx::AnyPool`], so the same runner serves Postgres / MySQL / SQLite — the
8//! migration SQL is yours to keep portable to whichever you target.
9//!
10//! ```no_run
11//! use klauthed_data::migrate::{Migration, Migrator};
12//!
13//! # async fn run(pool: &klauthed_data::AnyPool) -> Result<(), klauthed_data::DataError> {
14//! let migrator = Migrator::new([
15//!     Migration::new(1, "create_users", "CREATE TABLE users (id BIGINT PRIMARY KEY)"),
16//!     Migration::new(2, "add_email", "ALTER TABLE users ADD COLUMN email TEXT"),
17//! ])?;
18//! let applied = migrator.run(pool).await?;
19//! println!("applied {applied} migration(s)");
20//! # Ok(())
21//! # }
22//! ```
23
24use std::collections::BTreeSet;
25
26use sqlx::{AnyPool, Row};
27
28use crate::error::DataError;
29
/// A single forward ("up") migration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Version number. Must be unique within a [`Migrator`]; migrations are
    /// applied in ascending order of this value, and each version is recorded
    /// at most once in the tracking table.
    pub version: i64,
    /// Human-readable name, stored alongside the version and logged when the
    /// migration is applied.
    pub name: &'static str,
    /// The SQL to run. May contain multiple statements.
    pub sql: &'static str,
}
40
41impl Migration {
42    /// Construct a migration.
43    #[must_use]
44    pub const fn new(version: i64, name: &'static str, sql: &'static str) -> Self {
45        Self { version, name, sql }
46    }
47}
48
49/// Runs ordered [`Migration`]s against an [`AnyPool`], tracking applied versions
50/// so each runs exactly once.
51#[derive(Debug, Clone)]
52pub struct Migrator {
53    migrations: Vec<Migration>,
54}
55
56impl Migrator {
57    /// Build a migrator from a set of migrations (sorted by version).
58    ///
59    /// # Errors
60    /// Returns [`DataError::Migration`] if two migrations share a version.
61    pub fn new(migrations: impl IntoIterator<Item = Migration>) -> Result<Self, DataError> {
62        let mut migrations: Vec<Migration> = migrations.into_iter().collect();
63        migrations.sort_by_key(|m| m.version);
64
65        for pair in migrations.windows(2) {
66            if let [a, b] = pair
67                && a.version == b.version
68            {
69                return Err(DataError::Migration(format!(
70                    "duplicate migration version {}",
71                    a.version
72                )));
73            }
74        }
75        Ok(Self { migrations })
76    }
77
78    /// Apply every pending migration in version order; returns the number
79    /// applied. Re-running after a successful run applies nothing.
80    ///
81    /// Each migration runs in its own transaction, so a failure leaves earlier
82    /// migrations committed and the failing one rolled back.
83    ///
84    /// # Errors
85    /// Returns [`DataError`] on any SQL failure.
86    pub async fn run(&self, pool: &AnyPool) -> Result<u64, DataError> {
87        ensure_table(pool).await?;
88        let applied: BTreeSet<i64> = fetch_versions(pool).await?.into_iter().collect();
89
90        let mut count = 0u64;
91        for migration in &self.migrations {
92            if applied.contains(&migration.version) {
93                continue;
94            }
95            tracing::info!(
96                version = migration.version,
97                name = migration.name,
98                "applying migration"
99            );
100
101            let mut tx = pool.begin().await?;
102            sqlx::raw_sql(migration.sql).execute(&mut *tx).await?;
103            let record = format!(
104                "INSERT INTO _klauthed_migrations (version, name) VALUES ({}, '{}')",
105                migration.version,
106                migration.name.replace('\'', "''"),
107            );
108            // Audited: `version` is an integer and `name` is a `'static` literal
109            // (single quotes escaped), so this inline write is injection-safe.
110            sqlx::raw_sql(sqlx::AssertSqlSafe(record)).execute(&mut *tx).await?;
111            tx.commit().await?;
112            count += 1;
113        }
114        Ok(count)
115    }
116
117    /// The versions already recorded as applied, ascending.
118    ///
119    /// # Errors
120    /// Returns [`DataError`] on a SQL failure.
121    pub async fn applied(&self, pool: &AnyPool) -> Result<Vec<i64>, DataError> {
122        ensure_table(pool).await?;
123        fetch_versions(pool).await
124    }
125}
126
127/// Create the migration-tracking table if it doesn't exist (portable DDL).
128async fn ensure_table(pool: &AnyPool) -> Result<(), DataError> {
129    sqlx::raw_sql(
130        "CREATE TABLE IF NOT EXISTS _klauthed_migrations (\
131         version BIGINT PRIMARY KEY, \
132         name TEXT NOT NULL, \
133         applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP)",
134    )
135    .execute(pool)
136    .await?;
137    Ok(())
138}
139
140/// Read the recorded versions in ascending order.
141async fn fetch_versions(pool: &AnyPool) -> Result<Vec<i64>, DataError> {
142    let rows = sqlx::query("SELECT version FROM _klauthed_migrations ORDER BY version")
143        .fetch_all(pool)
144        .await?;
145    let mut versions = Vec::with_capacity(rows.len());
146    for row in &rows {
147        versions.push(row.try_get::<i64, _>("version")?);
148    }
149    Ok(versions)
150}
151
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;

    /// A fresh in-memory SQLite database behind an `AnyPool`.
    async fn memory_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        // `sqlite::memory:` is private per connection, so pin the pool to one
        // connection — otherwise each query could hit a different empty database.
        sqlx::any::AnyPoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn applies_pending_then_is_idempotent() {
        let pool = memory_pool().await;
        let migrator = Migrator::new([
            Migration::new(1, "create_users", "CREATE TABLE users (id BIGINT PRIMARY KEY)"),
            Migration::new(2, "add_email", "ALTER TABLE users ADD COLUMN email TEXT"),
        ])
        .unwrap();

        assert_eq!(migrator.run(&pool).await.unwrap(), 2);
        assert_eq!(migrator.applied(&pool).await.unwrap(), vec![1, 2]);

        // A second run applies nothing.
        assert_eq!(migrator.run(&pool).await.unwrap(), 0);

        // The migrated schema is usable.
        sqlx::raw_sql("INSERT INTO users (id, email) VALUES (1, 'a@b.c')")
            .execute(&pool)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn applies_in_version_order_regardless_of_input_order() {
        let pool = memory_pool().await;
        // Version 2 needs version 1's table, so this only succeeds if the
        // migrator sorts before applying.
        let migrator = Migrator::new([
            Migration::new(2, "add_email", "ALTER TABLE users ADD COLUMN email TEXT"),
            Migration::new(1, "create_users", "CREATE TABLE users (id BIGINT PRIMARY KEY)"),
        ])
        .unwrap();

        assert_eq!(migrator.run(&pool).await.unwrap(), 2);
        assert_eq!(migrator.applied(&pool).await.unwrap(), vec![1, 2]);
    }

    #[tokio::test]
    async fn failed_migration_is_not_recorded() {
        let pool = memory_pool().await;
        let migrator = Migrator::new([
            Migration::new(1, "create_users", "CREATE TABLE users (id BIGINT PRIMARY KEY)"),
            Migration::new(2, "broken", "THIS IS NOT SQL"),
        ])
        .unwrap();

        assert!(migrator.run(&pool).await.is_err());
        // Version 1 committed before version 2 failed; version 2 was rolled back.
        assert_eq!(migrator.applied(&pool).await.unwrap(), vec![1]);
    }

    #[tokio::test]
    async fn records_names_containing_quotes() {
        let pool = memory_pool().await;
        let migrator = Migrator::new([Migration::new(1, "it's_fine", "SELECT 1")]).unwrap();

        assert_eq!(migrator.run(&pool).await.unwrap(), 1);
        let row = sqlx::query("SELECT name FROM _klauthed_migrations WHERE version = 1")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(row.try_get::<String, _>("name").unwrap(), "it's_fine");
    }

    // Pure, synchronous validation: no async runtime needed.
    #[test]
    fn rejects_duplicate_versions() {
        let result =
            Migrator::new([Migration::new(1, "a", "SELECT 1"), Migration::new(1, "b", "SELECT 1")]);
        assert!(matches!(result, Err(DataError::Migration(_))));

        // Duplicates are caught even when they are not adjacent in the input.
        let result = Migrator::new([
            Migration::new(1, "a", "SELECT 1"),
            Migration::new(2, "b", "SELECT 1"),
            Migration::new(1, "c", "SELECT 1"),
        ]);
        assert!(matches!(result, Err(DataError::Migration(_))));
    }
}