// forge_runtime/migrations/runner.rs
1//! Migration runner with mesh-safe locking.
2//!
3//! Ensures only one node runs migrations at a time using PostgreSQL advisory locks.
4//!
5//! # Migration Types
6//!
7//! This runner handles two types of migrations:
8//!
9//! 1. **System migrations** (`__forge_vXXX`): Internal FORGE schema changes.
10//!    These are versioned numerically and always run before user migrations.
11//!    Legacy installations may have `0000_forge_internal` which is treated as v001.
12//!
13//! 2. **User migrations** (`XXXX_name.sql`): Application-specific schema changes.
14//!    These are sorted alphabetically by name.
15
16use forge_core::error::{ForgeError, Result};
17use sqlx::{PgPool, Postgres};
18use std::collections::HashSet;
19use std::path::Path;
20use tracing::{debug, info, warn};
21
22use super::builtin::extract_version;
23
24/// Lock ID for migration advisory lock (arbitrary but consistent).
25/// Using a fixed value derived from "FORGE" ascii values.
26const MIGRATION_LOCK_ID: i64 = 0x464F524745; // "FORGE" in hex
27
/// A single migration with up and optional down SQL.
///
/// `name` doubles as the unique key in the `forge_migrations` tracking table.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Unique name/identifier (e.g., "0001_forge_internal" or "0002_create_users").
    /// User migrations are applied in alphabetical order of this name.
    pub name: String,
    /// SQL to execute for upgrade (forward migration).
    pub up_sql: String,
    /// SQL to execute for rollback (optional).
    /// Recorded in the tracking table so `rollback` can replay it later.
    pub down_sql: Option<String>,
}
38
39impl Migration {
40    /// Create a migration with only up SQL (no rollback).
41    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
42        Self {
43            name: name.into(),
44            up_sql: sql.into(),
45            down_sql: None,
46        }
47    }
48
49    /// Create a migration with both up and down SQL.
50    pub fn with_down(
51        name: impl Into<String>,
52        up_sql: impl Into<String>,
53        down_sql: impl Into<String>,
54    ) -> Self {
55        Self {
56            name: name.into(),
57            up_sql: up_sql.into(),
58            down_sql: Some(down_sql.into()),
59        }
60    }
61
62    /// Parse migration content that may contain -- @up and -- @down markers.
63    pub fn parse(name: impl Into<String>, content: &str) -> Self {
64        let name = name.into();
65        let (up_sql, down_sql) = parse_migration_content(content);
66        Self {
67            name,
68            up_sql,
69            down_sql,
70        }
71    }
72}
73
/// Parse migration content, splitting on a `-- @down` marker.
/// Returns (up_sql, Option<down_sql>).
///
/// The marker may be written with or without a space (`--@down`) and in any
/// letter case (`-- @Down`, `-- @DOWN`, ...) — the previous implementation
/// claimed case-insensitivity but only matched four exact spellings.
/// The split happens at the *earliest* marker occurrence, so mixing the two
/// spellings in one file still gives a stable split point.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    // Find the earliest down marker, case-insensitively.
    if let Some((idx, len)) = find_marker(content, &["-- @down", "--@down"]) {
        // Clean up the up part (remove any -- @up markers).
        let up_sql = strip_up_markers(&content[..idx]);
        let down_sql = content[idx + len..].trim().to_string();

        if down_sql.is_empty() {
            return (up_sql, None);
        }
        return (up_sql, Some(down_sql));
    }

    // No @down marker found - treat entire content as up SQL
    (strip_up_markers(content), None)
}

/// Find the earliest case-insensitive occurrence of any of `markers`
/// (markers must be lowercase). Returns `(byte_offset, marker_len)`.
/// ASCII lowercasing preserves byte offsets, so an index into the lowered
/// copy is valid for the original string.
fn find_marker(content: &str, markers: &[&str]) -> Option<(usize, usize)> {
    let lowered = content.to_ascii_lowercase();
    markers
        .iter()
        .filter_map(|m| lowered.find(m).map(|idx| (idx, m.len())))
        .min_by_key(|&(idx, _)| idx)
}

/// Remove all case-insensitive `-- @up` / `--@up` markers and trim the result.
fn strip_up_markers(sql: &str) -> String {
    let stripped = remove_marker_ci(sql, "-- @up");
    let stripped = remove_marker_ci(&stripped, "--@up");
    stripped.trim().to_string()
}

/// Remove every case-insensitive occurrence of `marker` (lowercase) from `text`.
fn remove_marker_ci(text: &str, marker: &str) -> String {
    let lowered = text.to_ascii_lowercase();
    let mut out = String::with_capacity(text.len());
    let mut pos = 0;
    while let Some(rel) = lowered[pos..].find(marker) {
        out.push_str(&text[pos..pos + rel]);
        pos += rel + marker.len();
    }
    out.push_str(&text[pos..]);
    out
}
114
/// Migration runner that handles both built-in and user migrations.
///
/// All cross-node coordination happens through PostgreSQL (an advisory lock
/// plus the `forge_migrations` tracking table), so multiple runner instances
/// in a cluster can call `run` concurrently.
pub struct MigrationRunner {
    // Pool used for all queries; the advisory lock is taken on a dedicated
    // connection checked out from this pool.
    pool: PgPool,
}
119
impl MigrationRunner {
    /// Wrap an existing connection pool. No I/O happens here; the tracking
    /// table and lock are handled lazily by `run`/`rollback`/`status`.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Run all pending migrations with mesh-safe locking.
    ///
    /// This acquires an exclusive advisory lock before running migrations,
    /// ensuring only one node in the cluster runs migrations at a time.
    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
        // Acquire exclusive lock (blocks until acquired) on a dedicated connection.
        // Advisory locks are session-scoped, so the lock lives on its own
        // connection rather than whichever pooled connection a query uses.
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.run_migrations_inner(user_migrations).await;

        // Always release lock, even on error
        // (a failed unlock is only logged: Postgres drops the lock anyway
        // when the session closes).
        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    /// Apply pending system migrations (in version order), then pending user
    /// migrations. Caller must already hold the migration lock.
    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
        // Ensure migration tracking table exists
        self.ensure_migrations_table().await?;

        // Get already-applied migrations
        let applied = self.get_applied_migrations().await?;
        debug!("Already applied migrations: {:?}", applied);

        // Calculate the highest system version already applied
        let max_applied_version = self.get_max_system_version(&applied);
        debug!("Max applied system version: {:?}", max_applied_version);

        // Run built-in FORGE system migrations first (in version order)
        let system_migrations = super::builtin::get_system_migrations();
        for sys_migration in system_migrations {
            // Skip if this version (or equivalent legacy) is already applied
            if let Some(max_ver) = max_applied_version
                && sys_migration.version <= max_ver
            {
                debug!(
                    "Skipping system migration v{} (already at v{})",
                    sys_migration.version, max_ver
                );
                continue;
            }

            let migration = sys_migration.to_migration();
            info!(
                "Applying system migration: {} ({})",
                migration.name, sys_migration.description
            );
            self.apply_migration(&migration).await?;
        }

        // Then run user migrations (sorted by name)
        // NOTE(review): execution order here is whatever the caller passed;
        // `load_migrations_from_dir` sorts by name before calling this.
        for migration in user_migrations {
            if !applied.contains(&migration.name) {
                self.apply_migration(&migration).await?;
            }
        }

        Ok(())
    }

    /// Get the maximum system migration version that has been applied.
    /// Considers both new-style `__forge_vXXX` and legacy `0000_forge_internal`.
    fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
        // `extract_version` yields None for user migrations, so only system
        // migration names contribute to the maximum.
        applied
            .iter()
            .filter_map(|name| extract_version(name))
            .max()
    }

    /// Take the cluster-wide migration lock on a dedicated pool connection.
    /// `pg_advisory_lock` blocks until granted, so this can wait indefinitely
    /// while another node is migrating.
    async fn acquire_lock_connection(&self) -> Result<sqlx::pool::PoolConnection<Postgres>> {
        debug!("Acquiring migration lock...");
        let mut conn = self.pool.acquire().await.map_err(|e| {
            ForgeError::Database(format!("Failed to acquire lock connection: {}", e))
        })?;

        sqlx::query("SELECT pg_advisory_lock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut *conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
            })?;
        debug!("Migration lock acquired");
        Ok(conn)
    }

    /// Release the advisory lock taken by `acquire_lock_connection`.
    /// Must run on the same connection that took the lock (session-scoped).
    async fn release_lock_connection(
        &self,
        conn: &mut sqlx::pool::PoolConnection<Postgres>,
    ) -> Result<()> {
        sqlx::query("SELECT pg_advisory_unlock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut **conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to release migration lock: {}", e))
            })?;
        debug!("Migration lock released");
        Ok(())
    }

    /// Create the `forge_migrations` tracking table if needed, and backfill
    /// the `down_sql` column on installations that predate rollback support.
    async fn ensure_migrations_table(&self) -> Result<()> {
        // Create table if not exists
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS forge_migrations (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) UNIQUE NOT NULL,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                down_sql TEXT
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;

        // Add down_sql column if it doesn't exist (for existing installations)
        sqlx::query(
            r#"
            ALTER TABLE forge_migrations
            ADD COLUMN IF NOT EXISTS down_sql TEXT
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;

        Ok(())
    }

    /// Fetch the set of migration names already recorded as applied.
    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
        let rows = sqlx::query!("SELECT name FROM forge_migrations")
            .fetch_all(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
            })?;

        Ok(rows.into_iter().map(|row| row.name).collect())
    }

    /// Execute a migration's up SQL and record it in `forge_migrations`.
    ///
    /// NOTE(review): statements are executed one at a time on the pool, not
    /// inside a transaction — if a statement fails partway through, earlier
    /// statements stay applied and the migration is left unrecorded.
    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
        info!("Applying migration: {}", migration.name);

        // Split migration into individual statements, respecting dollar-quoted strings
        let statements = split_sql_statements(&migration.up_sql);

        for statement in statements {
            let statement = statement.trim();

            // Skip empty statements or comment-only blocks
            if statement.is_empty()
                || statement.lines().all(|l| {
                    let l = l.trim();
                    l.is_empty() || l.starts_with("--")
                })
            {
                continue;
            }

            sqlx::query(statement)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to apply migration '{}': {}",
                        migration.name, e
                    ))
                })?;
        }

        // Record it as applied (with down_sql for potential rollback)
        sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
            .bind(&migration.name)
            .bind(&migration.down_sql)
            .execute(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!(
                    "Failed to record migration '{}': {}",
                    migration.name, e
                ))
            })?;

        info!("Migration applied: {}", migration.name);
        Ok(())
    }

    /// Rollback N migrations (most recent first).
    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
        if count == 0 {
            return Ok(Vec::new());
        }

        // Acquire exclusive lock on a dedicated connection.
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.rollback_inner(count).await;

        // Always release lock
        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    /// Roll back the `count` most recently applied migrations.
    /// Returns the rolled-back names, newest first.
    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
        self.ensure_migrations_table().await?;

        // Get the N most recent migrations with their down_sql
        // NOTE(review): `count as i32` truncates for counts above i32::MAX —
        // unreachable in practice, but `i32::try_from(count)` would be safer.
        let rows = sqlx::query!(
            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
            count as i32
        )
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

        if rows.is_empty() {
            info!("No migrations to rollback");
            return Ok(Vec::new());
        }

        let mut rolled_back = Vec::new();

        for row in rows {
            let id = row.id;
            let name = row.name;
            let down_sql = row.down_sql;
            info!("Rolling back migration: {}", name);

            if let Some(down) = down_sql {
                // Execute down SQL
                // (same statement splitting and comment filtering as apply)
                let statements = split_sql_statements(&down);
                for statement in statements {
                    let statement = statement.trim();
                    if statement.is_empty()
                        || statement.lines().all(|l| {
                            let l = l.trim();
                            l.is_empty() || l.starts_with("--")
                        })
                    {
                        continue;
                    }

                    sqlx::query(statement)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| {
                            ForgeError::Database(format!(
                                "Failed to rollback migration '{}': {}",
                                name, e
                            ))
                        })?;
                }
            } else {
                // No down SQL: still drop the record so the migration can re-run.
                warn!("Migration '{}' has no down SQL, removing record only", name);
            }

            // Remove from migrations table
            sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
                .bind(id)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to remove migration record '{}': {}",
                        name, e
                    ))
                })?;

            info!("Rolled back migration: {}", name);
            rolled_back.push(name);
        }

        Ok(rolled_back)
    }

    /// Get the status of all migrations.
    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;

        let applied_list: Vec<AppliedMigration> = {
            let rows = sqlx::query!(
                "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC"
            )
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

            rows.into_iter()
                .map(|row| AppliedMigration {
                    name: row.name,
                    applied_at: row.applied_at,
                    has_down: row.down_sql.is_some(),
                })
                .collect()
        };

        // Pending = available migrations (by name) minus the applied set.
        let pending: Vec<String> = available
            .iter()
            .filter(|m| !applied.contains(&m.name))
            .map(|m| m.name.clone())
            .collect();

        Ok(MigrationStatus {
            applied: applied_list,
            pending,
        })
    }
}
442
/// Information about an applied migration.
#[derive(Debug, Clone)]
pub struct AppliedMigration {
    /// Migration name as recorded in `forge_migrations`.
    pub name: String,
    /// When the migration was applied (the `applied_at` column).
    pub applied_at: chrono::DateTime<chrono::Utc>,
    /// Whether rollback SQL was recorded (`down_sql` is non-NULL).
    pub has_down: bool,
}
450
/// Status of migrations.
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// Migrations recorded as applied, in application (id) order.
    pub applied: Vec<AppliedMigration>,
    /// Names of available migrations not yet applied.
    pub pending: Vec<String>,
}
457
/// Split SQL into individual statements, respecting dollar-quoted strings,
/// standard single-quoted literals, and `--` line comments.
///
/// This handles PL/pgSQL functions that contain semicolons inside $$ delimiters,
/// as well as semicolons inside `'...'` literals (e.g. `VALUES ('a;b')`) and
/// inside comments — both of which the previous version split on incorrectly.
/// Statements are returned trimmed with the trailing `;` removed.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut in_single_quote = false;
    let mut in_line_comment = false;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        // Inside a `--` line comment nothing is significant until end of line.
        if in_line_comment {
            if c == '\n' {
                in_line_comment = false;
            }
            continue;
        }

        // Inside a single-quoted literal only the closing quote matters.
        // An escaped quote (`''`) works out naturally: the first `'` exits the
        // literal and the second immediately re-enters it.
        if in_single_quote {
            if c == '\'' {
                in_single_quote = false;
            }
            continue;
        }

        if !in_dollar_quote {
            // Start of a single-quoted literal. Quotes inside $$...$$ are
            // plain text and are intentionally not tracked.
            if c == '\'' {
                in_single_quote = true;
                continue;
            }
            // Start of a `--` line comment (outside all quoting).
            if c == '-' && chars.peek() == Some(&'-') {
                // Safe: peek confirmed the char exists
                current.push(chars.next().expect("peeked char"));
                in_line_comment = true;
                continue;
            }
        }

        // Check for dollar-quoting start/end
        if c == '$' {
            // Look for a dollar-quote tag like $$ or $tag$
            let mut potential_tag = String::from("$");

            // Collect characters until we hit another $ or non-identifier char
            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    // Safe: peek confirmed the char exists
                    potential_tag.push(chars.next().expect("peeked char"));
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    let c = chars.next().expect("peeked char");
                    potential_tag.push(c);
                    current.push(c);
                } else {
                    break;
                }
            }

            // Check if this is a valid dollar-quote delimiter (ends with $)
            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                if in_dollar_quote && potential_tag == dollar_tag {
                    // End of dollar-quoted string
                    in_dollar_quote = false;
                    dollar_tag.clear();
                } else if !in_dollar_quote {
                    // Start of dollar-quoted string
                    in_dollar_quote = true;
                    dollar_tag = potential_tag;
                }
            }
        }

        // Split on semicolon only when outside every quoting construct
        // (single-quote and comment states are handled by the continues above).
        if c == ';' && !in_dollar_quote {
            let stmt = current.trim().trim_end_matches(';').trim().to_string();
            if !stmt.is_empty() {
                statements.push(stmt);
            }
            current.clear();
        }
    }

    // Don't forget the last statement (might not end with ;)
    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}
523
524/// Load user migrations from a directory.
525///
526/// Migrations should be named like:
527/// - `0001_create_users.sql`
528/// - `0002_add_posts.sql`
529///
530/// They are sorted alphabetically and executed in order.
531pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
532    if !dir.exists() {
533        debug!("Migrations directory does not exist: {:?}", dir);
534        return Ok(Vec::new());
535    }
536
537    let mut migrations = Vec::new();
538
539    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
540
541    for entry in entries {
542        let entry = entry.map_err(ForgeError::Io)?;
543        let path = entry.path();
544
545        if path.extension().map(|e| e == "sql").unwrap_or(false) {
546            let name = path
547                .file_stem()
548                .and_then(|s| s.to_str())
549                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
550                .to_string();
551
552            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
553
554            migrations.push(Migration::parse(name, &content));
555        }
556    }
557
558    // Sort by name (which includes the numeric prefix)
559    migrations.sort_by(|a, b| a.name.cmp(&b.name));
560
561    debug!("Loaded {} user migrations", migrations.len());
562    Ok(migrations)
563}
564
// Unit tests for the pure parts of the runner: directory loading, marker
// parsing, and statement splitting. DB-dependent paths (locking, apply,
// rollback, status) require a live Postgres and are not covered here.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        // A missing directory is not an error — it yields no migrations.
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        // Create migrations out of order
        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        // Only files whose final extension is exactly `sql` are loaded.
        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    #[test]
    fn test_migration_parse_up_only() {
        // No markers: the whole content is the up SQL.
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    #[test]
    fn test_migration_parse_complex() {
        // Multiple statements on both sides of the @down marker.
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        // Semicolons inside the $$ body must not split the statement.
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}