forge_runtime/migrations/runner.rs

//! Migration runner with mesh-safe locking.
//!
//! Ensures only one node runs migrations at a time using PostgreSQL advisory locks.
//!
//! # Migration Types
//!
//! This runner handles two types of migrations:
//!
//! 1. **System migrations** (`__forge_vXXX`): Internal FORGE schema changes.
//!    These are versioned numerically and always run before user migrations.
//!    Legacy installations may have `0000_forge_internal`, which is treated as v001.
//!
//! 2. **User migrations** (`XXXX_name.sql`): Application-specific schema changes.
//!    These are sorted alphabetically by name.
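//!
//! # Usage (sketch)
//!
//! A minimal sketch, assuming an already-connected [`sqlx::PgPool`] and a `./migrations`
//! directory; the path and pool setup here are illustrative, not prescribed:
//!
//! ```ignore
//! let migrations = load_migrations_from_dir(std::path::Path::new("./migrations"))?;
//! MigrationRunner::new(pool.clone()).run(migrations).await?;
//! ```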

use forge_core::error::{ForgeError, Result};
use sqlx::{PgPool, Postgres};
use std::collections::HashSet;
use std::path::Path;
use tracing::{debug, info, warn};

use super::builtin::extract_version;

/// Lock ID for migration advisory lock (arbitrary but consistent).
/// Uses a fixed value derived from the ASCII bytes of "FORGE".
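/// The bytes 'F' (0x46), 'O' (0x4F), 'R' (0x52), 'G' (0x47), 'E' (0x45) concatenate to
/// 0x46_4F_52_47_45.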
const MIGRATION_LOCK_ID: i64 = 0x464F524745; // "FORGE" in hex

/// A single migration with up and optional down SQL.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Unique name/identifier (e.g., "0001_forge_internal" or "0002_create_users").
    pub name: String,
    /// SQL to execute for upgrade (forward migration).
    pub up_sql: String,
    /// SQL to execute for rollback (optional).
    pub down_sql: Option<String>,
}

impl Migration {
    /// Create a migration with only up SQL (no rollback).
    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            up_sql: sql.into(),
            down_sql: None,
        }
    }

    /// Create a migration with both up and down SQL.
    pub fn with_down(
        name: impl Into<String>,
        up_sql: impl Into<String>,
        down_sql: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            up_sql: up_sql.into(),
            down_sql: Some(down_sql.into()),
        }
    }

    /// Parse migration content that may contain -- @up and -- @down markers.
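    ///
    /// A small sketch of the expected marker format (the SQL here is illustrative):
    ///
    /// ```ignore
    /// let m = Migration::parse(
    ///     "0002_add_posts",
    ///     "-- @up\nCREATE TABLE posts (id INT);\n-- @down\nDROP TABLE posts;",
    /// );
    /// assert_eq!(m.down_sql.as_deref(), Some("DROP TABLE posts;"));
    /// ```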
    pub fn parse(name: impl Into<String>, content: &str) -> Self {
        let name = name.into();
        let (up_sql, down_sql) = parse_migration_content(content);
        Self {
            name,
            up_sql,
            down_sql,
        }
    }
}

/// Parse migration content, splitting on -- @down marker.
/// Returns (up_sql, Option<down_sql>).
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    // Look for a -- @down marker (common lower/upper-case and spacing variants listed below)
    let down_marker_patterns = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];

    for pattern in down_marker_patterns {
        if let Some(idx) = content.find(pattern) {
            let up_part = &content[..idx];
            let down_part = &content[idx + pattern.len()..];

            // Clean up the up part (remove -- @up marker if present)
            let up_sql = up_part
                .replace("-- @up", "")
                .replace("--@up", "")
                .replace("-- @UP", "")
                .replace("--@UP", "")
                .trim()
                .to_string();

            let down_sql = down_part.trim().to_string();

            if down_sql.is_empty() {
                return (up_sql, None);
            }
            return (up_sql, Some(down_sql));
        }
    }

    // No @down marker found - treat entire content as up SQL
    let up_sql = content
        .replace("-- @up", "")
        .replace("--@up", "")
        .replace("-- @UP", "")
        .replace("--@UP", "")
        .trim()
        .to_string();

    (up_sql, None)
}

/// Migration runner that handles both built-in and user migrations.
pub struct MigrationRunner {
    pool: PgPool,
}

impl MigrationRunner {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Run all pending migrations with mesh-safe locking.
    ///
    /// This acquires an exclusive advisory lock before running migrations,
    /// ensuring only one node in the cluster runs migrations at a time.
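    ///
    /// The lock is a session-level `pg_advisory_lock`, so PostgreSQL releases it
    /// automatically if the holding connection is closed.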
    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
        // Acquire exclusive lock (blocks until acquired) on a dedicated connection.
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.run_migrations_inner(user_migrations).await;

        // Always release lock, even on error
        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
        // Ensure migration tracking table exists
        self.ensure_migrations_table().await?;

        // Get already-applied migrations
        let applied = self.get_applied_migrations().await?;
        debug!("Already applied migrations: {:?}", applied);

        // Calculate the highest system version already applied
        let max_applied_version = self.get_max_system_version(&applied);
        debug!("Max applied system version: {:?}", max_applied_version);

        // Run built-in FORGE system migrations first (in version order)
        let system_migrations = super::builtin::get_system_migrations();
        for sys_migration in system_migrations {
            // Skip if this version (or equivalent legacy) is already applied
            if let Some(max_ver) = max_applied_version
                && sys_migration.version <= max_ver
            {
                debug!(
                    "Skipping system migration v{} (already at v{})",
                    sys_migration.version, max_ver
                );
                continue;
            }

            let migration = sys_migration.to_migration();
            info!(
                "Applying system migration: {} ({})",
                migration.name, sys_migration.description
            );
            self.apply_migration(&migration).await?;
        }

        // Then run user migrations (sorted by name)
        for migration in user_migrations {
            if !applied.contains(&migration.name) {
                self.apply_migration(&migration).await?;
            }
        }

        Ok(())
    }

    /// Get the maximum system migration version that has been applied.
    /// Considers both new-style `__forge_vXXX` and legacy `0000_forge_internal`.
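    ///
    /// For example, if the applied set contains `__forge_v001` and `__forge_v003`, this
    /// returns `Some(3)` (assuming `extract_version` parses the numeric suffix of
    /// `__forge_vXXX` names, as described in the module docs).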
    fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
        applied
            .iter()
            .filter_map(|name| extract_version(name))
            .max()
    }

    async fn acquire_lock_connection(&self) -> Result<sqlx::pool::PoolConnection<Postgres>> {
        debug!("Acquiring migration lock...");
        let mut conn = self.pool.acquire().await.map_err(|e| {
            ForgeError::Database(format!("Failed to acquire lock connection: {}", e))
        })?;

        sqlx::query("SELECT pg_advisory_lock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut *conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
            })?;
        debug!("Migration lock acquired");
        Ok(conn)
    }

    async fn release_lock_connection(
        &self,
        conn: &mut sqlx::pool::PoolConnection<Postgres>,
    ) -> Result<()> {
        sqlx::query("SELECT pg_advisory_unlock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut **conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to release migration lock: {}", e))
            })?;
        debug!("Migration lock released");
        Ok(())
    }

    async fn ensure_migrations_table(&self) -> Result<()> {
        // Create table if not exists
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS forge_migrations (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) UNIQUE NOT NULL,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                down_sql TEXT
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;

        // Add down_sql column if it doesn't exist (for existing installations)
        sqlx::query(
            r#"
            ALTER TABLE forge_migrations
            ADD COLUMN IF NOT EXISTS down_sql TEXT
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;

        Ok(())
    }

    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
        let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM forge_migrations")
            .fetch_all(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
            })?;

        Ok(rows.into_iter().map(|(name,)| name).collect())
    }

    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
        info!("Applying migration: {}", migration.name);

        // Split migration into individual statements, respecting dollar-quoted strings
        let statements = split_sql_statements(&migration.up_sql);

        for statement in statements {
            let statement = statement.trim();

            // Skip empty statements or comment-only blocks
            if statement.is_empty()
                || statement.lines().all(|l| {
                    let l = l.trim();
                    l.is_empty() || l.starts_with("--")
                })
            {
                continue;
            }

            sqlx::query(statement)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to apply migration '{}': {}",
                        migration.name, e
                    ))
                })?;
        }

        // Record it as applied (with down_sql for potential rollback)
        sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
            .bind(&migration.name)
            .bind(&migration.down_sql)
            .execute(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!(
                    "Failed to record migration '{}': {}",
                    migration.name, e
                ))
            })?;

        info!("Migration applied: {}", migration.name);
        Ok(())
    }

    /// Roll back N migrations (most recent first).
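    ///
    /// Returns the names of the rolled-back migrations, most recent first.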
    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
        if count == 0 {
            return Ok(Vec::new());
        }

        // Acquire exclusive lock on a dedicated connection.
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.rollback_inner(count).await;

        // Always release lock
        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
        self.ensure_migrations_table().await?;

        // Get the N most recent migrations with their down_sql
        let rows: Vec<(i32, String, Option<String>)> = sqlx::query_as(
            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
        )
        .bind(count as i32)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

        if rows.is_empty() {
            info!("No migrations to rollback");
            return Ok(Vec::new());
        }

        let mut rolled_back = Vec::new();

        for (id, name, down_sql) in rows {
            info!("Rolling back migration: {}", name);

            if let Some(down) = down_sql {
                // Execute down SQL
                let statements = split_sql_statements(&down);
                for statement in statements {
                    let statement = statement.trim();
                    if statement.is_empty()
                        || statement.lines().all(|l| {
                            let l = l.trim();
                            l.is_empty() || l.starts_with("--")
                        })
                    {
                        continue;
                    }

                    sqlx::query(statement)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| {
                            ForgeError::Database(format!(
                                "Failed to rollback migration '{}': {}",
                                name, e
                            ))
                        })?;
                }
            } else {
                warn!("Migration '{}' has no down SQL, removing record only", name);
            }

            // Remove from migrations table
            sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
                .bind(id)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to remove migration record '{}': {}",
                        name, e
                    ))
                })?;

            info!("Rolled back migration: {}", name);
            rolled_back.push(name);
        }

        Ok(rolled_back)
    }

    /// Get the status of all migrations.
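    ///
    /// `available` is the full set of known user migrations; any that are not yet
    /// recorded in `forge_migrations` are reported as pending.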
    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;

        let applied_list: Vec<AppliedMigration> = {
            let rows: Vec<(String, chrono::DateTime<chrono::Utc>, Option<String>)> =
                sqlx::query_as(
                    "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC",
                )
                .fetch_all(&self.pool)
                .await
                .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

            rows.into_iter()
                .map(|(name, applied_at, down_sql)| AppliedMigration {
                    name,
                    applied_at,
                    has_down: down_sql.is_some(),
                })
                .collect()
        };

        let pending: Vec<String> = available
            .iter()
            .filter(|m| !applied.contains(&m.name))
            .map(|m| m.name.clone())
            .collect();

        Ok(MigrationStatus {
            applied: applied_list,
            pending,
        })
    }
}

/// Information about an applied migration.
#[derive(Debug, Clone)]
pub struct AppliedMigration {
    pub name: String,
    pub applied_at: chrono::DateTime<chrono::Utc>,
    pub has_down: bool,
}

/// Status of migrations.
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    pub applied: Vec<AppliedMigration>,
    pub pending: Vec<String>,
}

/// Split SQL into individual statements, respecting dollar-quoted strings.
/// This handles PL/pgSQL functions that contain semicolons inside $$ delimiters.
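///
/// For example, `"SELECT 1; SELECT 2;"` splits into two statements, while the body of a
/// `$$ ... $$` (or `$tag$ ... $tag$`) block stays within a single statement even if it
/// contains semicolons.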
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        // Check for dollar-quoting start/end
        if c == '$' {
            // Look for a dollar-quote tag like $$ or $tag$
            let mut potential_tag = String::from("$");

            // Collect characters until we hit another $ or non-identifier char
            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    // Safe: peek confirmed the char exists
                    potential_tag.push(chars.next().expect("peeked char"));
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    let c = chars.next().expect("peeked char");
                    potential_tag.push(c);
                    current.push(c);
                } else {
                    break;
                }
            }

            // Check if this is a valid dollar-quote delimiter (ends with $)
            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                if in_dollar_quote && potential_tag == dollar_tag {
                    // End of dollar-quoted string
                    in_dollar_quote = false;
                    dollar_tag.clear();
                } else if !in_dollar_quote {
                    // Start of dollar-quoted string
                    in_dollar_quote = true;
                    dollar_tag = potential_tag;
                }
            }
        }

        // Split on semicolon only if not inside a dollar-quoted string
        if c == ';' && !in_dollar_quote {
            let stmt = current.trim().trim_end_matches(';').trim().to_string();
            if !stmt.is_empty() {
                statements.push(stmt);
            }
            current.clear();
        }
    }

    // Don't forget the last statement (might not end with ;)
    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}

/// Load user migrations from a directory.
///
/// Migrations should be named like:
/// - `0001_create_users.sql`
/// - `0002_add_posts.sql`
///
/// They are sorted alphabetically and executed in order.
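///
/// Each file may optionally contain `-- @up` / `-- @down` markers; the text after
/// `-- @down` is stored as rollback SQL (see [`Migration::parse`]).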
pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
    if !dir.exists() {
        debug!("Migrations directory does not exist: {:?}", dir);
        return Ok(Vec::new());
    }

    let mut migrations = Vec::new();

    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;

    for entry in entries {
        let entry = entry.map_err(ForgeError::Io)?;
        let path = entry.path();

        if path.extension().map(|e| e == "sql").unwrap_or(false) {
            let name = path
                .file_stem()
                .and_then(|s| s.to_str())
                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
                .to_string();

            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;

            migrations.push(Migration::parse(name, &content));
        }
    }

    // Sort by name (which includes the numeric prefix)
    migrations.sort_by(|a, b| a.name.cmp(&b.name));

    debug!("Loaded {} user migrations", migrations.len());
    Ok(migrations)
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        // Create migrations out of order
        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    #[test]
    fn test_migration_parse_up_only() {
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}