1use std::collections::BTreeSet;
25
26use sqlx::{AnyPool, Row};
27
28use crate::error::DataError;
29
/// A single, versioned schema change.
///
/// A [`Migrator`] applies migrations in ascending `version` order and records each applied
/// migration's `version` and `name` in the `_klauthed_migrations` table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Migration {
    /// Ordering key of the migration; must be unique within a [`Migrator`].
    pub version: i64,
    /// Human-readable label, stored next to the version once the migration is applied.
    pub name: &'static str,
    /// SQL executed verbatim (via `sqlx::raw_sql`) when the migration is applied.
    pub sql: &'static str,
}

impl Migration {
    /// Creates a migration from its version, name, and SQL text.
    ///
    /// Being a `const fn`, this can be used to declare migrations in `const` or `static` items.
    #[must_use]
    pub const fn new(version: i64, name: &'static str, sql: &'static str) -> Self {
        Self { version, name, sql }
    }
}
48
49#[derive(Debug, Clone)]
52pub struct Migrator {
53 migrations: Vec<Migration>,
54}
55
56impl Migrator {
57 pub fn new(migrations: impl IntoIterator<Item = Migration>) -> Result<Self, DataError> {
62 let mut migrations: Vec<Migration> = migrations.into_iter().collect();
63 migrations.sort_by_key(|m| m.version);
64
65 for pair in migrations.windows(2) {
66 if let [a, b] = pair
67 && a.version == b.version
68 {
69 return Err(DataError::Migration(format!(
70 "duplicate migration version {}",
71 a.version
72 )));
73 }
74 }
75 Ok(Self { migrations })
76 }
77
78 pub async fn run(&self, pool: &AnyPool) -> Result<u64, DataError> {
87 ensure_table(pool).await?;
88 let applied: BTreeSet<i64> = fetch_versions(pool).await?.into_iter().collect();
89
90 let mut count = 0u64;
91 for migration in &self.migrations {
92 if applied.contains(&migration.version) {
93 continue;
94 }
95 tracing::info!(
96 version = migration.version,
97 name = migration.name,
98 "applying migration"
99 );
100
101 let mut tx = pool.begin().await?;
102 sqlx::raw_sql(migration.sql).execute(&mut *tx).await?;
103 let record = format!(
104 "INSERT INTO _klauthed_migrations (version, name) VALUES ({}, '{}')",
105 migration.version,
106 migration.name.replace('\'', "''"),
107 );
108 sqlx::raw_sql(sqlx::AssertSqlSafe(record)).execute(&mut *tx).await?;
111 tx.commit().await?;
112 count += 1;
113 }
114 Ok(count)
115 }
116
117 pub async fn applied(&self, pool: &AnyPool) -> Result<Vec<i64>, DataError> {
122 ensure_table(pool).await?;
123 fetch_versions(pool).await
124 }
125}
126
127async fn ensure_table(pool: &AnyPool) -> Result<(), DataError> {
129 sqlx::raw_sql(
130 "CREATE TABLE IF NOT EXISTS _klauthed_migrations (\
131 version BIGINT PRIMARY KEY, \
132 name TEXT NOT NULL, \
133 applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP)",
134 )
135 .execute(pool)
136 .await?;
137 Ok(())
138}
139
140async fn fetch_versions(pool: &AnyPool) -> Result<Vec<i64>, DataError> {
142 let rows = sqlx::query("SELECT version FROM _klauthed_migrations ORDER BY version")
143 .fetch_all(pool)
144 .await?;
145 let mut versions = Vec::with_capacity(rows.len());
146 for row in &rows {
147 versions.push(row.try_get::<i64, _>("version")?);
148 }
149 Ok(versions)
150}
151
#[cfg(all(test, feature = "sqlite"))]
mod tests {
    use super::*;

    /// Opens a pool over an in-memory SQLite database, limited to a single connection.
    ///
    /// Each SQLite `:memory:` connection gets its own private database, so the single
    /// connection keeps every query in a test on the same database.
    async fn memory_pool() -> AnyPool {
        sqlx::any::install_default_drivers();
        let options = sqlx::any::AnyPoolOptions::new().max_connections(1);
        options.connect("sqlite::memory:").await.unwrap()
    }

    #[tokio::test]
    async fn applies_pending_then_is_idempotent() {
        let pool = memory_pool().await;
        let migrations = [
            Migration::new(1, "create_users", "CREATE TABLE users (id BIGINT PRIMARY KEY)"),
            Migration::new(2, "add_email", "ALTER TABLE users ADD COLUMN email TEXT"),
        ];
        let migrator = Migrator::new(migrations).unwrap();

        // The first run applies and records both migrations.
        let first_run = migrator.run(&pool).await.unwrap();
        assert_eq!(first_run, 2);
        let recorded = migrator.applied(&pool).await.unwrap();
        assert_eq!(recorded, vec![1, 2]);

        // The second run finds nothing pending.
        let second_run = migrator.run(&pool).await.unwrap();
        assert_eq!(second_run, 0);

        // The resulting schema accepts a row that uses the column added by migration 2.
        sqlx::raw_sql("INSERT INTO users (id, email) VALUES (1, 'a@b.c')")
            .execute(&pool)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn rejects_duplicate_versions() {
        let colliding = [Migration::new(1, "a", "SELECT 1"), Migration::new(1, "b", "SELECT 1")];
        let outcome = Migrator::new(colliding);
        assert!(matches!(outcome, Err(DataError::Migration(_))));
    }
}