1use std::error::Error;
2
3use sqlx::{Database, Executor, Pool};
4
/// A single schema migration: a version key plus the SQL that applies and
/// reverts it.
///
/// Once applied, `version` and `name` are recorded in the
/// `_premix_migrations` bookkeeping table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Identifier recorded once the migration is applied. A migration is
    /// treated as already applied when an exactly equal version string is
    /// present in `_premix_migrations`.
    pub version: String,
    /// Human-readable label stored next to the version.
    pub name: String,
    /// SQL executed (as a raw, unprepared query) to apply the migration.
    pub up_sql: String,
    /// SQL intended to revert `up_sql`. `Migrator::run` does not execute it.
    pub down_sql: String,
}
12
13#[derive(Debug, Clone, sqlx::FromRow)]
14struct AppliedMigration {
15 version: String,
16}
17
18pub struct Migrator<DB: Database> {
19 pool: Pool<DB>,
20}
21
22impl<DB: Database> Migrator<DB> {
23 pub fn new(pool: Pool<DB>) -> Self {
24 Self { pool }
25 }
26}
27
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in the `_premix_migrations` table.
    ///
    /// All of the following happen inside a single transaction that is only
    /// committed once every pending migration has succeeded:
    /// - creating the bookkeeping table if it does not exist;
    /// - executing each pending migration's `up_sql`;
    /// - recording each pending migration's version.
    ///
    /// Migrations are applied in the order given. The caller is responsible
    /// for passing them sorted by version.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered (for example a syntax
    /// error in `up_sql`). The transaction is then dropped without being
    /// committed, so none of this call's changes persist.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(
            "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )",
        )
        .execute(&mut *tx)
        .await?;

        // A set gives O(1) membership checks in the loop below; row order is irrelevant.
        let applied: std::collections::HashSet<String> =
            sqlx::query_as::<_, AppliedMigration>("SELECT version FROM _premix_migrations")
                .fetch_all(&mut *tx)
                .await?
                .into_iter()
                .map(|m| m.version)
                .collect();

        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }
}
82
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in the `_premix_migrations` table.
    ///
    /// All of the following happen inside a single transaction that is only
    /// committed once every pending migration has succeeded:
    /// - creating the bookkeeping table if it does not exist;
    /// - executing each pending migration's `up_sql`;
    /// - recording each pending migration's version.
    ///
    /// Migrations are applied in the order given. The caller is responsible
    /// for passing them sorted by version.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered (for example a syntax
    /// error in `up_sql`). The transaction is then dropped without being
    /// committed, so none of this call's changes persist.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(
            "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
            )",
        )
        .execute(&mut *tx)
        .await?;

        // A set gives O(1) membership checks in the loop below; row order is irrelevant.
        let applied: std::collections::HashSet<String> =
            sqlx::query_as::<_, AppliedMigration>("SELECT version FROM _premix_migrations")
                .fetch_all(&mut *tx)
                .await?
                .into_iter()
                .map(|m| m.version)
                .collect();

        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }
}
139
#[cfg(test)]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    // Gated like `Migrator<sqlx::Sqlite>::run`: without the `sqlite` feature
    // that method does not exist and the SQLite tests cannot compile.
    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    use super::*;

    /// Opens a single-connection pool over a fresh in-memory SQLite database.
    ///
    /// The pool is capped at one connection because every SQLite connection
    /// to an in-memory database gets its own private database. A second
    /// connection would not see the tables created through the first.
    #[cfg(feature = "sqlite")]
    async fn sqlite_memory_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("failed to open in-memory SQLite pool")
    }

    /// Returns the number of rows currently recorded in `_premix_migrations`.
    #[cfg(feature = "sqlite")]
    async fn sqlite_applied_count(pool: &sqlx::Pool<sqlx::Sqlite>) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(pool)
            .await
            .expect("failed to count applied migrations")
    }

    /// Connects to `DATABASE_URL`, falling back to a local default.
    ///
    /// Returns `None` when the connection fails, so the Postgres tests return
    /// early instead of failing.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260101000000".to_string(),
            name: "create_users".to_string(),
            up_sql: "CREATE TABLE users (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "DROP TABLE users;".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(sqlite_applied_count(&pool).await, 1);

        // A second run must see the version as applied and record nothing new.
        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_applied_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ];

        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_applied_count(&pool).await, 2);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260102000000".to_string(),
            name: "bad_sql".to_string(),
            up_sql: "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL".to_string(),
            down_sql: "DROP TABLE broken;".to_string(),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The failed run must not leave even the bookkeeping table behind.
        let table: Option<String> = sqlx::query_scalar(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='_premix_migrations'",
        )
        .fetch_optional(&pool)
        .await
        .unwrap();
        assert!(table.is_none());
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        // Unique suffix so concurrent or repeated runs against a shared
        // database do not collide on versions or table names.
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "create_test_table".to_string(),
            up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        migrator.run(migrations.clone()).await.unwrap();

        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count, 1);

        migrator.run(migrations).await.unwrap();
        let count_after: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count_after, 1);

        // Best-effort cleanup of the shared database.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "bad_sql".to_string(),
            up_sql: format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table may already exist from other runs; if it
        // does, the failed version must not have been recorded in it.
        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            let count: i64 =
                sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                    .bind(&version)
                    .fetch_one(&pool)
                    .await
                    .unwrap();
            assert_eq!(count, 0);
        }

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }
}