forge_runtime/migrations/
runner.rs1use forge_core::error::{ForgeError, Result};
17use sqlx::{PgPool, Postgres};
18use std::collections::HashSet;
19use std::path::Path;
20use tracing::{debug, info, warn};
21
22use super::builtin::extract_version;
23
/// Key passed to `pg_advisory_lock` / `pg_advisory_unlock` while migrations run.
///
/// The bytes `46 4F 52 47 45` are the ASCII encoding of `FORGE`.
const MIGRATION_LOCK_ID: i64 = 0x464F524745;

/// A single named schema migration.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Identifier stored in the `forge_migrations` table once the migration is applied.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL executed when the migration is rolled back, if any was provided.
    pub down_sql: Option<String>,
}
38
39impl Migration {
40 pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
42 Self {
43 name: name.into(),
44 up_sql: sql.into(),
45 down_sql: None,
46 }
47 }
48
49 pub fn with_down(
51 name: impl Into<String>,
52 up_sql: impl Into<String>,
53 down_sql: impl Into<String>,
54 ) -> Self {
55 Self {
56 name: name.into(),
57 up_sql: up_sql.into(),
58 down_sql: Some(down_sql.into()),
59 }
60 }
61
62 pub fn parse(name: impl Into<String>, content: &str) -> Self {
64 let name = name.into();
65 let (up_sql, down_sql) = parse_migration_content(content);
66 Self {
67 name,
68 up_sql,
69 down_sql,
70 }
71 }
72}
73
/// Splits the text of a migration file into its up SQL and optional down SQL.
///
/// A *marker* is a line whose first non-blank text is `--`, optional
/// whitespace, then `@up` or `@down` (ASCII case-insensitive) ending at a word
/// boundary. Requiring a line start and a word boundary keeps ordinary
/// comments such as `-- @update later` or `-- @downgrade notes` intact instead
/// of being mistaken for markers.
///
/// * Everything before the first `@down` marker is the up SQL; `@up` markers
///   inside it are removed.
/// * Everything after the first `@down` marker is the down SQL.
/// * Text following a marker on the same line is kept as SQL.
///
/// Both parts are trimmed. The down part is `None` when there is no `@down`
/// marker or when nothing but whitespace follows it.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    /// Returns the text following the marker when `line` is a `-- @<keyword>`
    /// marker line, or `None` when it is not.
    fn marker_tail<'a>(line: &'a str, keyword: &str) -> Option<&'a str> {
        let after_dashes = line.trim_start().strip_prefix("--")?;
        let after_at = after_dashes.trim_start().strip_prefix('@')?;
        let head = after_at.get(..keyword.len())?;
        if !head.eq_ignore_ascii_case(keyword) {
            return None;
        }
        let tail = after_at.get(keyword.len()..)?;
        // `@upgrade` / `@download` are ordinary words, not markers.
        let at_word_end = !tail.starts_with(|c: char| c.is_alphanumeric() || c == '_');
        at_word_end.then_some(tail)
    }

    // `split_inclusive` keeps each line terminator, so the SQL is reassembled
    // with its original line endings.
    let mut lines = content.split_inclusive('\n');
    let mut up_sql = String::with_capacity(content.len());
    let mut down_sql = None;

    while let Some(line) = lines.next() {
        if let Some(tail) = marker_tail(line, "down") {
            // The rest of the file, starting right after the marker, is the down SQL.
            let mut down = String::from(tail);
            down.extend(lines.by_ref());
            let down = down.trim();
            if !down.is_empty() {
                down_sql = Some(down.to_string());
            }
            break;
        }

        up_sql.push_str(marker_tail(line, "up").unwrap_or(line));
    }

    (up_sql.trim().to_string(), down_sql)
}
114
115pub struct MigrationRunner {
117 pool: PgPool,
118}
119
120impl MigrationRunner {
121 pub fn new(pool: PgPool) -> Self {
122 Self { pool }
123 }
124
125 pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
130 let mut lock_conn = self.acquire_lock_connection().await?;
132
133 let result = self.run_migrations_inner(user_migrations).await;
134
135 if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
137 warn!("Failed to release migration lock: {}", e);
138 }
139
140 result
141 }
142
143 async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
144 self.ensure_migrations_table().await?;
146
147 let applied = self.get_applied_migrations().await?;
149 debug!("Already applied migrations: {:?}", applied);
150
151 let max_applied_version = self.get_max_system_version(&applied);
153 debug!("Max applied system version: {:?}", max_applied_version);
154
155 let system_migrations = super::builtin::get_system_migrations();
157 for sys_migration in system_migrations {
158 if let Some(max_ver) = max_applied_version
160 && sys_migration.version <= max_ver
161 {
162 debug!(
163 "Skipping system migration v{} (already at v{})",
164 sys_migration.version, max_ver
165 );
166 continue;
167 }
168
169 let migration = sys_migration.to_migration();
170 info!(
171 "Applying system migration: {} ({})",
172 migration.name, sys_migration.description
173 );
174 self.apply_migration(&migration).await?;
175 }
176
177 for migration in user_migrations {
179 if !applied.contains(&migration.name) {
180 self.apply_migration(&migration).await?;
181 }
182 }
183
184 Ok(())
185 }
186
187 fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
190 applied
191 .iter()
192 .filter_map(|name| extract_version(name))
193 .max()
194 }
195
196 async fn acquire_lock_connection(&self) -> Result<sqlx::pool::PoolConnection<Postgres>> {
197 debug!("Acquiring migration lock...");
198 let mut conn = self.pool.acquire().await.map_err(|e| {
199 ForgeError::Database(format!("Failed to acquire lock connection: {}", e))
200 })?;
201
202 sqlx::query("SELECT pg_advisory_lock($1)")
203 .bind(MIGRATION_LOCK_ID)
204 .execute(&mut *conn)
205 .await
206 .map_err(|e| {
207 ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
208 })?;
209 debug!("Migration lock acquired");
210 Ok(conn)
211 }
212
213 async fn release_lock_connection(
214 &self,
215 conn: &mut sqlx::pool::PoolConnection<Postgres>,
216 ) -> Result<()> {
217 sqlx::query("SELECT pg_advisory_unlock($1)")
218 .bind(MIGRATION_LOCK_ID)
219 .execute(&mut **conn)
220 .await
221 .map_err(|e| {
222 ForgeError::Database(format!("Failed to release migration lock: {}", e))
223 })?;
224 debug!("Migration lock released");
225 Ok(())
226 }
227
228 async fn ensure_migrations_table(&self) -> Result<()> {
229 sqlx::query(
231 r#"
232 CREATE TABLE IF NOT EXISTS forge_migrations (
233 id SERIAL PRIMARY KEY,
234 name VARCHAR(255) UNIQUE NOT NULL,
235 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
236 down_sql TEXT
237 )
238 "#,
239 )
240 .execute(&self.pool)
241 .await
242 .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;
243
244 sqlx::query(
246 r#"
247 ALTER TABLE forge_migrations
248 ADD COLUMN IF NOT EXISTS down_sql TEXT
249 "#,
250 )
251 .execute(&self.pool)
252 .await
253 .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;
254
255 Ok(())
256 }
257
258 async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
259 let rows = sqlx::query!("SELECT name FROM forge_migrations")
260 .fetch_all(&self.pool)
261 .await
262 .map_err(|e| {
263 ForgeError::Database(format!("Failed to get applied migrations: {}", e))
264 })?;
265
266 Ok(rows.into_iter().map(|row| row.name).collect())
267 }
268
269 async fn apply_migration(&self, migration: &Migration) -> Result<()> {
270 info!("Applying migration: {}", migration.name);
271
272 let statements = split_sql_statements(&migration.up_sql);
274
275 for statement in statements {
276 let statement = statement.trim();
277
278 if statement.is_empty()
280 || statement.lines().all(|l| {
281 let l = l.trim();
282 l.is_empty() || l.starts_with("--")
283 })
284 {
285 continue;
286 }
287
288 sqlx::query(statement)
289 .execute(&self.pool)
290 .await
291 .map_err(|e| {
292 ForgeError::Database(format!(
293 "Failed to apply migration '{}': {}",
294 migration.name, e
295 ))
296 })?;
297 }
298
299 sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
301 .bind(&migration.name)
302 .bind(&migration.down_sql)
303 .execute(&self.pool)
304 .await
305 .map_err(|e| {
306 ForgeError::Database(format!(
307 "Failed to record migration '{}': {}",
308 migration.name, e
309 ))
310 })?;
311
312 info!("Migration applied: {}", migration.name);
313 Ok(())
314 }
315
316 pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
318 if count == 0 {
319 return Ok(Vec::new());
320 }
321
322 let mut lock_conn = self.acquire_lock_connection().await?;
324
325 let result = self.rollback_inner(count).await;
326
327 if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
329 warn!("Failed to release migration lock: {}", e);
330 }
331
332 result
333 }
334
335 async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
336 self.ensure_migrations_table().await?;
337
338 let rows = sqlx::query!(
340 "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
341 count as i32
342 )
343 .fetch_all(&self.pool)
344 .await
345 .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
346
347 if rows.is_empty() {
348 info!("No migrations to rollback");
349 return Ok(Vec::new());
350 }
351
352 let mut rolled_back = Vec::new();
353
354 for row in rows {
355 let id = row.id;
356 let name = row.name;
357 let down_sql = row.down_sql;
358 info!("Rolling back migration: {}", name);
359
360 if let Some(down) = down_sql {
361 let statements = split_sql_statements(&down);
363 for statement in statements {
364 let statement = statement.trim();
365 if statement.is_empty()
366 || statement.lines().all(|l| {
367 let l = l.trim();
368 l.is_empty() || l.starts_with("--")
369 })
370 {
371 continue;
372 }
373
374 sqlx::query(statement)
375 .execute(&self.pool)
376 .await
377 .map_err(|e| {
378 ForgeError::Database(format!(
379 "Failed to rollback migration '{}': {}",
380 name, e
381 ))
382 })?;
383 }
384 } else {
385 warn!("Migration '{}' has no down SQL, removing record only", name);
386 }
387
388 sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
390 .bind(id)
391 .execute(&self.pool)
392 .await
393 .map_err(|e| {
394 ForgeError::Database(format!(
395 "Failed to remove migration record '{}': {}",
396 name, e
397 ))
398 })?;
399
400 info!("Rolled back migration: {}", name);
401 rolled_back.push(name);
402 }
403
404 Ok(rolled_back)
405 }
406
407 pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
409 self.ensure_migrations_table().await?;
410
411 let applied = self.get_applied_migrations().await?;
412
413 let applied_list: Vec<AppliedMigration> = {
414 let rows = sqlx::query!(
415 "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC"
416 )
417 .fetch_all(&self.pool)
418 .await
419 .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
420
421 rows.into_iter()
422 .map(|row| AppliedMigration {
423 name: row.name,
424 applied_at: row.applied_at,
425 has_down: row.down_sql.is_some(),
426 })
427 .collect()
428 };
429
430 let pending: Vec<String> = available
431 .iter()
432 .filter(|m| !applied.contains(&m.name))
433 .map(|m| m.name.clone())
434 .collect();
435
436 Ok(MigrationStatus {
437 applied: applied_list,
438 pending,
439 })
440 }
441}
442
443#[derive(Debug, Clone)]
445pub struct AppliedMigration {
446 pub name: String,
447 pub applied_at: chrono::DateTime<chrono::Utc>,
448 pub has_down: bool,
449}
450
451#[derive(Debug, Clone)]
453pub struct MigrationStatus {
454 pub applied: Vec<AppliedMigration>,
455 pub pending: Vec<String>,
456}
457
/// Splits a SQL script into individual statements at top-level semicolons.
///
/// A semicolon does **not** end a statement when it appears inside:
///
/// * a dollar-quoted body (`$$ ... $$` or `$tag$ ... $tag$`),
/// * a single-quoted literal (`'a;b'`; `''` is an embedded quote, and
///   backslash escapes are honoured for `E'...'` literals),
/// * a double-quoted identifier (`"a;b"`),
/// * a `--` line comment or a (nestable) `/* ... */` block comment.
///
/// Each returned statement is trimmed and has its terminating semicolon
/// removed; pieces that are empty after trimming are dropped. An unterminated
/// quote or comment swallows the rest of the input into the current statement.
fn split_sql_statements(sql: &str) -> Vec<String> {
    /// Characters allowed in an identifier or in a dollar-quote tag.
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// True when `preceding` ends with a standalone `E`/`e`, i.e. the quote
    /// that follows opens an escape-string literal.
    fn is_escape_string_prefix(preceding: &str) -> bool {
        let mut rev = preceding.chars().rev();
        matches!(rev.next(), Some('E' | 'e')) && !rev.next().is_some_and(is_ident_char)
    }

    /// Byte length of the quoted token at the start of `text`, which begins
    /// with `quote`. A doubled quote simply closes one token and opens the
    /// next, which is equivalent for splitting purposes.
    fn quoted_len(text: &str, quote: char, backslash_escapes: bool) -> usize {
        let mut chars = text.char_indices().skip(1);
        while let Some((idx, c)) = chars.next() {
            if backslash_escapes && c == '\\' {
                // The escaped character can never close the literal.
                chars.next();
            } else if c == quote {
                return idx + quote.len_utf8();
            }
        }
        text.len()
    }

    /// Byte length of the `--` comment at the start of `text`, newline included.
    fn line_comment_len(text: &str) -> usize {
        text.find('\n').map_or(text.len(), |idx| idx + 1)
    }

    /// Byte length of the `/* ... */` comment at the start of `text`.
    /// Nested comments are tracked by depth.
    fn block_comment_len(text: &str) -> usize {
        let mut depth = 0usize;
        let mut chars = text.char_indices().peekable();
        while let Some((idx, c)) = chars.next() {
            let next = chars.peek().map(|&(_, n)| n);
            match (c, next) {
                ('/', Some('*')) => {
                    chars.next();
                    depth += 1;
                }
                ('*', Some('/')) => {
                    chars.next();
                    depth = depth.saturating_sub(1);
                    if depth == 0 {
                        return idx + 2;
                    }
                }
                _ => {}
            }
        }
        text.len()
    }

    /// Byte length of the dollar-quoted string at the start of `text`, or
    /// `None` when `text` (which begins with `$`) does not open one, for
    /// example a `$1` parameter or a lone `$`.
    fn dollar_quoted_len(text: &str) -> Option<usize> {
        let after_dollar = text.strip_prefix('$')?;
        if after_dollar.starts_with(|c: char| c.is_ascii_digit()) {
            return None;
        }
        let ident_len = after_dollar.find(|c: char| !is_ident_char(c))?;
        let (ident, tail) = after_dollar.split_at(ident_len);
        let body = tail.strip_prefix('$')?;

        let tag_len = ident.len() + 2;
        let (tag, _) = text.split_at(tag_len);
        Some(
            body.find(tag)
                .map_or(text.len(), |close| tag_len + close + tag_len),
        )
    }

    /// Stores `raw` as a statement unless it is blank.
    fn push_statement(statements: &mut Vec<String>, raw: &str) {
        let statement = raw.trim();
        if !statement.is_empty() {
            statements.push(statement.to_string());
        }
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    let mut rest = sql;

    while let Some(first) = rest.chars().next() {
        if first == ';' {
            push_statement(&mut statements, &current);
            current.clear();
            rest = rest.split_at(1).1;
            continue;
        }

        // Length of the next token, which is copied verbatim. Quoted strings
        // and comments are consumed whole so semicolons inside them are never
        // inspected.
        let token_len = match first {
            '\'' => quoted_len(rest, '\'', is_escape_string_prefix(&current)),
            '"' => quoted_len(rest, '"', false),
            '-' if rest.starts_with("--") => line_comment_len(rest),
            '/' if rest.starts_with("/*") => block_comment_len(rest),
            // A `$` glued to an identifier (`a$b$`) is part of that identifier.
            '$' if !current.ends_with(|c: char| is_ident_char(c) || c == '$') => {
                dollar_quoted_len(rest).unwrap_or(1)
            }
            other => other.len_utf8(),
        };

        let (token, remaining) = rest.split_at(token_len);
        current.push_str(token);
        rest = remaining;
    }

    push_statement(&mut statements, &current);
    statements
}
523
524pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
532 if !dir.exists() {
533 debug!("Migrations directory does not exist: {:?}", dir);
534 return Ok(Vec::new());
535 }
536
537 let mut migrations = Vec::new();
538
539 let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
540
541 for entry in entries {
542 let entry = entry.map_err(ForgeError::Io)?;
543 let path = entry.path();
544
545 if path.extension().map(|e| e == "sql").unwrap_or(false) {
546 let name = path
547 .file_stem()
548 .and_then(|s| s.to_str())
549 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
550 .to_string();
551
552 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
553
554 migrations.push(Migration::parse(name, &content));
555 }
556 }
557
558 migrations.sort_by(|a, b| a.name.cmp(&b.name));
560
561 debug!("Loaded {} user migrations", migrations.len());
562 Ok(migrations)
563}
564
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates `file_name` inside `dir` with the given contents.
    fn write_file(dir: &TempDir, file_name: &str, contents: &str) {
        fs::write(dir.path().join(file_name), contents).unwrap();
    }

    /// Loads the migrations in `dir` and returns only their names, in load order.
    fn loaded_names(dir: &TempDir) -> Vec<String> {
        load_migrations_from_dir(dir.path())
            .unwrap()
            .into_iter()
            .map(|m| m.name)
            .collect()
    }

    /// Joins `lines` with newlines, mimicking a multi-line file body.
    fn script(lines: &[&str]) -> String {
        lines.join("\n")
    }

    // --- load_migrations_from_dir -------------------------------------------

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        assert!(loaded_names(&dir).is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let loaded = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        // Written out of order on purpose.
        for (file_name, body) in [
            ("0002_second.sql", "SELECT 2;"),
            ("0001_first.sql", "SELECT 1;"),
            ("0003_third.sql", "SELECT 3;"),
        ] {
            write_file(&dir, file_name, body);
        }

        assert_eq!(
            loaded_names(&dir),
            ["0001_first", "0002_second", "0003_third"]
        );
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        write_file(&dir, "0001_migration.sql", "SELECT 1;");
        write_file(&dir, "readme.txt", "Not a migration");
        write_file(&dir, "backup.sql.bak", "Backup");

        assert_eq!(loaded_names(&dir), ["0001_migration"]);
    }

    // --- Migration constructors ---------------------------------------------

    #[test]
    fn test_migration_new() {
        let migration = Migration::new("test", "SELECT 1");
        assert_eq!(migration.name, "test");
        assert_eq!(migration.up_sql, "SELECT 1");
        assert!(migration.down_sql.is_none());
    }

    #[test]
    fn test_migration_with_down() {
        let migration = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(migration.name, "test");
        assert_eq!(migration.up_sql, "CREATE TABLE t()");
        assert_eq!(migration.down_sql, Some("DROP TABLE t".to_string()));
    }

    // --- Migration::parse ---------------------------------------------------

    #[test]
    fn test_migration_parse_up_only() {
        let migration = Migration::parse("0001_test", "CREATE TABLE users (id INT);");
        assert_eq!(migration.name, "0001_test");
        assert_eq!(migration.up_sql, "CREATE TABLE users (id INT);");
        assert!(migration.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = script(&[
            "",
            "-- @up",
            "CREATE TABLE users (",
            "    id UUID PRIMARY KEY,",
            "    email VARCHAR(255)",
            ");",
            "",
            "-- @down",
            "DROP TABLE users;",
            "",
        ]);

        let migration = Migration::parse("0001_users", &content);

        assert_eq!(migration.name, "0001_users");
        assert!(migration.up_sql.contains("CREATE TABLE users"));
        assert!(!migration.up_sql.contains("@up"));
        assert!(!migration.up_sql.contains("DROP TABLE"));
        assert_eq!(migration.down_sql, Some("DROP TABLE users;".to_string()));
    }

    #[test]
    fn test_migration_parse_complex() {
        let content = script(&[
            "",
            "-- @up",
            "CREATE TABLE posts (",
            "    id UUID PRIMARY KEY,",
            "    title TEXT NOT NULL",
            ");",
            "CREATE INDEX idx_posts_title ON posts(title);",
            "",
            "-- @down",
            "DROP INDEX idx_posts_title;",
            "DROP TABLE posts;",
            "",
        ]);

        let migration = Migration::parse("0002_posts", &content);

        assert!(migration.up_sql.contains("CREATE TABLE posts"));
        assert!(migration.up_sql.contains("CREATE INDEX"));
        let down = migration.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    // --- split_sql_statements -----------------------------------------------

    #[test]
    fn test_split_simple_statements() {
        let statements = super::split_sql_statements("SELECT 1; SELECT 2; SELECT 3;");
        assert_eq!(statements.len(), 3);
        assert_eq!(statements[0], "SELECT 1");
        assert_eq!(statements[1], "SELECT 2");
        assert_eq!(statements[2], "SELECT 3");
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = script(&[
            "",
            "CREATE FUNCTION test() RETURNS void AS $$",
            "BEGIN",
            "    SELECT 1;",
            "    SELECT 2;",
            "END;",
            "$$ LANGUAGE plpgsql;",
            "",
            "SELECT 3;",
            "",
        ]);

        let statements = super::split_sql_statements(&sql);

        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("CREATE FUNCTION"));
        assert!(statements[0].contains("$$ LANGUAGE plpgsql"));
        assert!(statements[1].contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = script(&[
            "",
            "CREATE FUNCTION notify() RETURNS trigger AS $$",
            "DECLARE",
            "    row_id TEXT;",
            "BEGIN",
            "    row_id := NEW.id::TEXT;",
            "    RETURN NEW;",
            "END;",
            "$$ LANGUAGE plpgsql;",
            "",
        ]);

        let statements = super::split_sql_statements(&sql);

        assert_eq!(statements.len(), 1);
        assert!(statements[0].contains("row_id := NEW.id::TEXT"));
    }
}