forge_runtime/migrations/
runner.rs1use forge_core::error::{ForgeError, Result};
16use sqlx::{PgPool, Postgres};
17use std::collections::HashSet;
18use std::path::Path;
19use tracing::{debug, info, warn};
20
21use super::builtin::extract_version;
22
23const MIGRATION_LOCK_ID: i64 = 0x464F524745; #[derive(Debug, Clone)]
29pub struct Migration {
30 pub name: String,
32 pub up_sql: String,
34 pub down_sql: Option<String>,
36}
37
38impl Migration {
39 pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
41 Self {
42 name: name.into(),
43 up_sql: sql.into(),
44 down_sql: None,
45 }
46 }
47
48 pub fn with_down(
50 name: impl Into<String>,
51 up_sql: impl Into<String>,
52 down_sql: impl Into<String>,
53 ) -> Self {
54 Self {
55 name: name.into(),
56 up_sql: up_sql.into(),
57 down_sql: Some(down_sql.into()),
58 }
59 }
60
61 pub fn parse(name: impl Into<String>, content: &str) -> Self {
63 let name = name.into();
64 let (up_sql, down_sql) = parse_migration_content(content);
65 Self {
66 name,
67 up_sql,
68 down_sql,
69 }
70 }
71}
72
/// The two sections a migration file can be divided into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MigrationSection {
    Up,
    Down,
}

/// Classifies `line` as an `@up` / `@down` section marker, if it is one.
///
/// A marker is a `--` line comment whose first token is `@up` or `@down`,
/// compared ASCII case-insensitively: `-- @down`, `--@DOWN` and
/// `-- @down revert users` all qualify. Comments that merely begin with the
/// same letters, such as `-- @update` or `-- @downgrade`, are ordinary comments.
fn migration_section_marker(line: &str) -> Option<MigrationSection> {
    let word = line
        .trim()
        .strip_prefix("--")?
        .trim_start()
        .strip_prefix('@')?
        .split(char::is_whitespace)
        .next()?;

    if word.eq_ignore_ascii_case("up") {
        Some(MigrationSection::Up)
    } else if word.eq_ignore_ascii_case("down") {
        Some(MigrationSection::Down)
    } else {
        None
    }
}

/// Splits migration file `content` into `(up_sql, down_sql)`.
///
/// Everything before the first `@down` marker line is up SQL, with any `@up`
/// marker lines removed. Everything after that marker line is down SQL.
/// Both parts are trimmed, and an empty down part yields `None`.
///
/// Markers are recognised only as whole comment lines. SQL or comments that
/// merely contain `@up` / `@down` text (e.g. `-- @update ...`) are kept
/// verbatim instead of being rewritten or treated as a section split.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    let mut up_sql = String::with_capacity(content.len());
    let mut down_part: Option<&str> = None;
    let mut consumed = 0;

    // `split_inclusive` keeps each line terminator, so `consumed` is always the
    // byte offset of the next line and the up text is reproduced exactly.
    for line in content.split_inclusive('\n') {
        consumed += line.len();
        match migration_section_marker(line) {
            Some(MigrationSection::Down) => {
                // The marker line itself (and any note after the marker) is
                // not part of the down SQL.
                down_part = content.get(consumed..);
                break;
            }
            // An `@up` marker only labels the section; drop the marker line.
            Some(MigrationSection::Up) => {}
            None => up_sql.push_str(line),
        }
    }

    let down_sql = down_part
        .map(str::trim)
        .filter(|sql| !sql.is_empty())
        .map(str::to_string);

    (up_sql.trim().to_string(), down_sql)
}
113
114pub struct MigrationRunner {
116 pool: PgPool,
117}
118
119impl MigrationRunner {
120 pub fn new(pool: PgPool) -> Self {
121 Self { pool }
122 }
123
124 pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
129 let mut lock_conn = self.acquire_lock_connection().await?;
131
132 let result = self.run_migrations_inner(user_migrations).await;
133
134 if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
136 warn!("Failed to release migration lock: {}", e);
137 }
138
139 result
140 }
141
142 async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
143 self.ensure_migrations_table().await?;
145
146 let applied = self.get_applied_migrations().await?;
148 debug!("Already applied migrations: {:?}", applied);
149
150 let max_applied_version = self.get_max_system_version(&applied);
152 debug!("Max applied system version: {:?}", max_applied_version);
153
154 let system_migrations = super::builtin::get_system_migrations();
156 for sys_migration in system_migrations {
157 if let Some(max_ver) = max_applied_version
159 && sys_migration.version <= max_ver
160 {
161 debug!(
162 "Skipping system migration v{} (already at v{})",
163 sys_migration.version, max_ver
164 );
165 continue;
166 }
167
168 let migration = sys_migration.to_migration();
169 info!(
170 "Applying system migration: {} ({})",
171 migration.name, sys_migration.description
172 );
173 self.apply_migration(&migration).await?;
174 }
175
176 for migration in user_migrations {
178 if !applied.contains(&migration.name) {
179 self.apply_migration(&migration).await?;
180 }
181 }
182
183 Ok(())
184 }
185
186 fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
188 applied
189 .iter()
190 .filter_map(|name| extract_version(name))
191 .max()
192 }
193
194 async fn acquire_lock_connection(&self) -> Result<sqlx::pool::PoolConnection<Postgres>> {
195 debug!("Acquiring migration lock...");
196 let mut conn = self.pool.acquire().await.map_err(|e| {
197 ForgeError::Database(format!("Failed to acquire lock connection: {}", e))
198 })?;
199
200 sqlx::query_scalar!("SELECT pg_advisory_lock($1)", MIGRATION_LOCK_ID)
201 .fetch_one(&mut *conn)
202 .await
203 .map_err(|e| {
204 ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
205 })?;
206 debug!("Migration lock acquired");
207 Ok(conn)
208 }
209
210 async fn release_lock_connection(
211 &self,
212 conn: &mut sqlx::pool::PoolConnection<Postgres>,
213 ) -> Result<()> {
214 sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", MIGRATION_LOCK_ID)
215 .fetch_one(&mut **conn)
216 .await
217 .map_err(|e| {
218 ForgeError::Database(format!("Failed to release migration lock: {}", e))
219 })?;
220 debug!("Migration lock released");
221 Ok(())
222 }
223
224 async fn ensure_migrations_table(&self) -> Result<()> {
225 sqlx::query(
227 r#"
228 CREATE TABLE IF NOT EXISTS forge_migrations (
229 id SERIAL PRIMARY KEY,
230 name VARCHAR(255) UNIQUE NOT NULL,
231 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
232 down_sql TEXT
233 )
234 "#,
235 )
236 .execute(&self.pool)
237 .await
238 .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;
239
240 Ok(())
241 }
242
243 async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
244 let rows = sqlx::query!("SELECT name FROM forge_migrations")
245 .fetch_all(&self.pool)
246 .await
247 .map_err(|e| {
248 ForgeError::Database(format!("Failed to get applied migrations: {}", e))
249 })?;
250
251 Ok(rows.into_iter().map(|row| row.name).collect())
252 }
253
254 async fn apply_migration(&self, migration: &Migration) -> Result<()> {
255 info!("Applying migration: {}", migration.name);
256
257 let statements = split_sql_statements(&migration.up_sql);
259
260 for statement in statements {
261 let statement = statement.trim();
262
263 if statement.is_empty()
265 || statement.lines().all(|l| {
266 let l = l.trim();
267 l.is_empty() || l.starts_with("--")
268 })
269 {
270 continue;
271 }
272
273 sqlx::query(statement)
274 .execute(&self.pool)
275 .await
276 .map_err(|e| {
277 ForgeError::Database(format!(
278 "Failed to apply migration '{}': {}",
279 migration.name, e
280 ))
281 })?;
282 }
283
284 sqlx::query!(
286 "INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)",
287 &migration.name,
288 migration.down_sql as _,
289 )
290 .execute(&self.pool)
291 .await
292 .map_err(|e| {
293 ForgeError::Database(format!(
294 "Failed to record migration '{}': {}",
295 migration.name, e
296 ))
297 })?;
298
299 info!("Migration applied: {}", migration.name);
300 Ok(())
301 }
302
303 pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
305 if count == 0 {
306 return Ok(Vec::new());
307 }
308
309 let mut lock_conn = self.acquire_lock_connection().await?;
311
312 let result = self.rollback_inner(count).await;
313
314 if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
316 warn!("Failed to release migration lock: {}", e);
317 }
318
319 result
320 }
321
322 async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
323 self.ensure_migrations_table().await?;
324
325 let rows = sqlx::query!(
327 "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
328 count as i32
329 )
330 .fetch_all(&self.pool)
331 .await
332 .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
333
334 if rows.is_empty() {
335 info!("No migrations to rollback");
336 return Ok(Vec::new());
337 }
338
339 let mut rolled_back = Vec::new();
340
341 for row in rows {
342 let id = row.id;
343 let name = row.name;
344 let down_sql = row.down_sql;
345 info!("Rolling back migration: {}", name);
346
347 if let Some(down) = down_sql {
348 let statements = split_sql_statements(&down);
350 for statement in statements {
351 let statement = statement.trim();
352 if statement.is_empty()
353 || statement.lines().all(|l| {
354 let l = l.trim();
355 l.is_empty() || l.starts_with("--")
356 })
357 {
358 continue;
359 }
360
361 sqlx::query(statement)
362 .execute(&self.pool)
363 .await
364 .map_err(|e| {
365 ForgeError::Database(format!(
366 "Failed to rollback migration '{}': {}",
367 name, e
368 ))
369 })?;
370 }
371 } else {
372 warn!("Migration '{}' has no down SQL, removing record only", name);
373 }
374
375 sqlx::query!("DELETE FROM forge_migrations WHERE id = $1", id)
377 .execute(&self.pool)
378 .await
379 .map_err(|e| {
380 ForgeError::Database(format!(
381 "Failed to remove migration record '{}': {}",
382 name, e
383 ))
384 })?;
385
386 info!("Rolled back migration: {}", name);
387 rolled_back.push(name);
388 }
389
390 Ok(rolled_back)
391 }
392
393 pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
395 self.ensure_migrations_table().await?;
396
397 let applied = self.get_applied_migrations().await?;
398
399 let applied_list: Vec<AppliedMigration> = {
400 let rows = sqlx::query!(
401 "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC"
402 )
403 .fetch_all(&self.pool)
404 .await
405 .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
406
407 rows.into_iter()
408 .map(|row| AppliedMigration {
409 name: row.name,
410 applied_at: row.applied_at,
411 has_down: row.down_sql.is_some(),
412 })
413 .collect()
414 };
415
416 let pending: Vec<String> = available
417 .iter()
418 .filter(|m| !applied.contains(&m.name))
419 .map(|m| m.name.clone())
420 .collect();
421
422 Ok(MigrationStatus {
423 applied: applied_list,
424 pending,
425 })
426 }
427}
428
429#[derive(Debug, Clone)]
431pub struct AppliedMigration {
432 pub name: String,
433 pub applied_at: chrono::DateTime<chrono::Utc>,
434 pub has_down: bool,
435}
436
437#[derive(Debug, Clone)]
439pub struct MigrationStatus {
440 pub applied: Vec<AppliedMigration>,
441 pub pending: Vec<String>,
442}
443
/// Lexical context of the SQL statement splitter at the current character.
enum SqlLexState {
    /// Plain SQL text, where `;` ends a statement.
    Normal,
    /// Inside a `'...'` literal. `backslash_escapes` is set for `E'...'`
    /// literals, where `\` escapes the next character.
    SingleQuoted { backslash_escapes: bool },
    /// Inside a `"..."` quoted identifier.
    DoubleQuoted,
    /// Inside a `--` comment, which runs to the end of the line.
    LineComment,
    /// Inside a `/* ... */` comment. PostgreSQL lets these nest, so track depth.
    BlockComment { depth: usize },
    /// Inside a dollar-quoted body opened by `tag` (`$$` or `$name$`).
    /// `body_start` is the byte offset in the statement buffer just past the
    /// opening tag, so the search for the closing tag never matches the opener.
    DollarQuoted { tag: String, body_start: usize },
}

/// True for characters that can continue an unquoted identifier or keyword.
fn is_sql_word_char(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

/// Prepares the statement buffered in `current` for output and clears the buffer.
///
/// The statement is trimmed and loses its trailing `;`. It is appended to
/// `statements` only if something non-empty remains.
fn flush_sql_statement(statements: &mut Vec<String>, current: &mut String) {
    let statement = current.trim().trim_end_matches(';').trim();
    if !statement.is_empty() {
        statements.push(statement.to_string());
    }
    current.clear();
}

/// Splits a SQL script into individual statements on top-level `;`.
///
/// Semicolons in the following places do not split a statement:
/// - single-quoted strings, including `''` and `E'\''` escapes;
/// - double-quoted identifiers;
/// - `--` comments and nested `/* */` comments;
/// - dollar-quoted bodies such as PL/pgSQL functions.
///
/// Each statement is trimmed and loses its trailing `;`. Empty statements are
/// dropped. Comments stay attached to the statement text they appear in.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut state = SqlLexState::Normal;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        state = match state {
            SqlLexState::Normal => match c {
                ';' => {
                    current.push(c);
                    flush_sql_statement(&mut statements, &mut current);
                    SqlLexState::Normal
                }
                '\'' => {
                    // `E'...'` enables backslash escapes, unless the `E` is just
                    // the tail of a longer word.
                    let mut preceding = current.chars().rev();
                    let backslash_escapes = matches!(preceding.next(), Some('E' | 'e'))
                        && !preceding.next().is_some_and(is_sql_word_char);
                    current.push(c);
                    SqlLexState::SingleQuoted { backslash_escapes }
                }
                '"' => {
                    current.push(c);
                    SqlLexState::DoubleQuoted
                }
                '-' if chars.peek() == Some(&'-') => {
                    chars.next();
                    current.push_str("--");
                    SqlLexState::LineComment
                }
                '/' if chars.peek() == Some(&'*') => {
                    chars.next();
                    current.push_str("/*");
                    SqlLexState::BlockComment { depth: 1 }
                }
                // A `$` glued to a preceding word belongs to that identifier.
                // PostgreSQL requires whitespace before a dollar quote there.
                '$' if !current.chars().next_back().is_some_and(is_sql_word_char) => {
                    current.push(c);
                    let mut tag = String::from("$");
                    // Tags follow identifier rules (no leading digit), so `$1`
                    // stays a positional parameter.
                    if chars
                        .peek()
                        .is_some_and(|&next| next.is_alphabetic() || next == '_')
                    {
                        while let Some(next) = chars.next_if(|&next| is_sql_word_char(next)) {
                            tag.push(next);
                            current.push(next);
                        }
                    }
                    if chars.next_if_eq(&'$').is_some() {
                        tag.push('$');
                        current.push('$');
                        SqlLexState::DollarQuoted {
                            tag,
                            body_start: current.len(),
                        }
                    } else {
                        SqlLexState::Normal
                    }
                }
                _ => {
                    current.push(c);
                    SqlLexState::Normal
                }
            },
            SqlLexState::SingleQuoted { backslash_escapes } => {
                current.push(c);
                match c {
                    '\\' if backslash_escapes => {
                        if let Some(escaped) = chars.next() {
                            current.push(escaped);
                        }
                        SqlLexState::SingleQuoted { backslash_escapes }
                    }
                    // `''` is an escaped quote, not the end of the literal.
                    '\'' if chars.peek() == Some(&'\'') => {
                        chars.next();
                        current.push('\'');
                        SqlLexState::SingleQuoted { backslash_escapes }
                    }
                    '\'' => SqlLexState::Normal,
                    _ => SqlLexState::SingleQuoted { backslash_escapes },
                }
            }
            SqlLexState::DoubleQuoted => {
                current.push(c);
                match c {
                    // `""` is an escaped quote inside a quoted identifier.
                    '"' if chars.peek() == Some(&'"') => {
                        chars.next();
                        current.push('"');
                        SqlLexState::DoubleQuoted
                    }
                    '"' => SqlLexState::Normal,
                    _ => SqlLexState::DoubleQuoted,
                }
            }
            SqlLexState::LineComment => {
                current.push(c);
                if c == '\n' {
                    SqlLexState::Normal
                } else {
                    SqlLexState::LineComment
                }
            }
            SqlLexState::BlockComment { depth } => {
                current.push(c);
                if c == '*' && chars.peek() == Some(&'/') {
                    chars.next();
                    current.push('/');
                    if depth == 1 {
                        SqlLexState::Normal
                    } else {
                        SqlLexState::BlockComment { depth: depth - 1 }
                    }
                } else if c == '/' && chars.peek() == Some(&'*') {
                    chars.next();
                    current.push('*');
                    SqlLexState::BlockComment { depth: depth + 1 }
                } else {
                    SqlLexState::BlockComment { depth }
                }
            }
            SqlLexState::DollarQuoted { tag, body_start } => {
                current.push(c);
                // The body ends at the first occurrence of the exact opening tag.
                let closes = c == '$'
                    && current
                        .get(body_start..)
                        .is_some_and(|body| body.ends_with(tag.as_str()));
                if closes {
                    SqlLexState::Normal
                } else {
                    SqlLexState::DollarQuoted { tag, body_start }
                }
            }
        };
    }

    flush_sql_statement(&mut statements, &mut current);
    statements
}
509
510pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
518 if !dir.exists() {
519 debug!("Migrations directory does not exist: {:?}", dir);
520 return Ok(Vec::new());
521 }
522
523 let mut migrations = Vec::new();
524
525 let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
526
527 for entry in entries {
528 let entry = entry.map_err(ForgeError::Io)?;
529 let path = entry.path();
530
531 if path.extension().map(|e| e == "sql").unwrap_or(false) {
532 let name = path
533 .file_stem()
534 .and_then(|s| s.to_str())
535 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
536 .to_string();
537
538 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
539
540 migrations.push(Migration::parse(name, &content));
541 }
542 }
543
544 migrations.sort_by(|a, b| a.name.cmp(&b.name));
546
547 debug!("Loaded {} user migrations", migrations.len());
548 Ok(migrations)
549}
550
#[cfg(test)]
// Unwraps, indexing and panics are acceptable here: any failure should abort
// the test that hit it.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::fs;
    use tempfile::TempDir;

    /// An existing but empty directory yields no migrations.
    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    /// A missing directory is not an error; it yields no migrations.
    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    /// Migrations come back ordered by name regardless of write order.
    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        // Written out of order on purpose.
        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    /// Only files whose final extension is `.sql` are loaded.
    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        // The final extension is `bak`, so this must be skipped too.
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    /// `Migration::new` stores the SQL as-is and has no down SQL.
    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    /// `Migration::with_down` stores both directions verbatim.
    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    /// Content without markers becomes up SQL only.
    #[test]
    fn test_migration_parse_up_only() {
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    /// `@up` / `@down` markers split the file, and the markers themselves are
    /// removed from the SQL.
    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    /// Multi-statement sections are kept whole on both sides of the split.
    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    /// The highest extracted version wins, and user migration names in the set
    /// do not interfere.
    #[tokio::test]
    async fn test_get_max_system_version_prefers_highest_applied_version() {
        // `connect_lazy` defers connecting, and `get_max_system_version` only
        // inspects the given set, so the unreachable URL is never used.
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("lazy pool must build");
        let runner = MigrationRunner::new(pool);

        let applied = HashSet::from([
            "__forge_v003".to_string(),
            "__forge_v001".to_string(),
            "0001_user_schema".to_string(),
        ]);

        assert_eq!(runner.get_max_system_version(&applied), Some(3));
    }

    /// Plain statements split on `;`, and the trailing `;` is stripped.
    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    /// Semicolons inside a `$$`-quoted function body do not split it.
    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    /// Dollar-quoted bodies are passed through unchanged (including `::` casts).
    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}