// forge_runtime/migrations/runner.rs

//! Migration runner with mesh-safe locking.
//!
//! Ensures only one node runs migrations at a time using PostgreSQL advisory locks.
//!
//! # Migration Types
//!
//! This runner handles two types of migrations:
//!
//! 1. **System migrations** (`__forge_vXXX`): Internal FORGE schema changes.
//!    These are versioned numerically and always run before user migrations.
//!    Legacy installations may have `0000_forge_internal`, which is treated as v001.
//!
//! 2. **User migrations** (`XXXX_name.sql`): Application-specific schema changes.
//!    The runner applies them in the order it is given them;
//!    `load_migrations_from_dir` returns them sorted alphabetically by name.

use forge_core::error::{ForgeError, Result};
use sqlx::{PgConnection, PgPool};
use std::collections::HashSet;
use std::path::Path;
use tracing::{debug, info, warn};

use super::builtin::extract_version;

/// Advisory-lock key shared by every node that runs migrations.
///
/// The key is the ASCII bytes of `"FORGE"` packed big-endian into an `i64`
/// (`0x46 0x4F 0x52 0x47 0x45`, i.e. `0x464F524745`). Any fixed value would
/// do, as long as all nodes agree on it.
const MIGRATION_LOCK_ID: i64 = i64::from_be_bytes([0, 0, 0, b'F', b'O', b'R', b'G', b'E']);

/// One schema change: forward SQL plus, optionally, the SQL that reverts it.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Identifier recorded in `forge_migrations`, e.g. `"0002_create_users"`.
    pub name: String,
    /// Forward (upgrade) SQL.
    pub up_sql: String,
    /// Rollback SQL, when the migration can be reverted.
    pub down_sql: Option<String>,
}

39impl Migration {
40    /// Create a migration with only up SQL (no rollback).
41    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
42        Self {
43            name: name.into(),
44            up_sql: sql.into(),
45            down_sql: None,
46        }
47    }
48
49    /// Create a migration with both up and down SQL.
50    pub fn with_down(
51        name: impl Into<String>,
52        up_sql: impl Into<String>,
53        down_sql: impl Into<String>,
54    ) -> Self {
55        Self {
56            name: name.into(),
57            up_sql: up_sql.into(),
58            down_sql: Some(down_sql.into()),
59        }
60    }
61
62    /// Parse migration content that may contain -- @up and -- @down markers.
63    pub fn parse(name: impl Into<String>, content: &str) -> Self {
64        let name = name.into();
65        let (up_sql, down_sql) = parse_migration_content(content);
66        Self {
67            name,
68            up_sql,
69            down_sql,
70        }
71    }
72}
73
/// Section markers recognised by `parse_migration_content`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SectionMarker {
    /// `-- @up`: start of the forward SQL (optional; it is the default section).
    Up,
    /// `-- @down`: start of the rollback SQL.
    Down,
}

/// Classify `line` as a section marker, if it is one.
///
/// A marker is a line whose first non-blank text is a `--` comment followed
/// (after optional whitespace) by `@up` or `@down` in any letter case, e.g.
/// `-- @down`, `--@DOWN`, `  -- @Down: rollback`. The keyword must end at a
/// non-identifier character, so comments such as `-- @updated_at ...` or
/// `-- @downgrade notes` stay ordinary comments.
fn section_marker(line: &str) -> Option<SectionMarker> {
    let directive = line
        .trim_start()
        .strip_prefix("--")?
        .trim_start()
        .strip_prefix('@')?;
    let keyword_end = directive
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(directive.len());
    let keyword = &directive[..keyword_end];
    if keyword.eq_ignore_ascii_case("up") {
        Some(SectionMarker::Up)
    } else if keyword.eq_ignore_ascii_case("down") {
        Some(SectionMarker::Down)
    } else {
        None
    }
}

/// Parse migration content, splitting on the first `-- @down` marker line.
/// Returns `(up_sql, Option<down_sql>)`.
///
/// Markers are recognised per line by `section_marker`, in any letter case.
/// A marker must sit on its own line; anything after the keyword on that line
/// is comment text and is dropped.
///
/// - Up SQL: everything before the first `@down` marker, with `@up` marker
///   lines removed.
/// - Down SQL: everything after that marker, verbatim.
///
/// Both parts are trimmed. An empty down section yields `None`, and content
/// without a `@down` marker is entirely up SQL.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    let mut up_sql = String::with_capacity(content.len());
    let mut down_sql: Option<String> = None;

    // `split_inclusive` keeps each line terminator, so SQL text (including
    // CRLF line endings) is copied through unchanged.
    for line in content.split_inclusive('\n') {
        if let Some(down) = down_sql.as_mut() {
            down.push_str(line);
            continue;
        }
        match section_marker(line) {
            Some(SectionMarker::Down) => down_sql = Some(String::new()),
            // An `@up` marker only labels the default section; it is not SQL.
            Some(SectionMarker::Up) => {}
            None => up_sql.push_str(line),
        }
    }

    let up_sql = up_sql.trim().to_string();
    let down_sql = down_sql
        .map(|down| down.trim().to_string())
        .filter(|down| !down.is_empty());
    (up_sql, down_sql)
}

115/// Migration runner that handles both built-in and user migrations.
116pub struct MigrationRunner {
117    pool: PgPool,
118}
119
120impl MigrationRunner {
121    pub fn new(pool: PgPool) -> Self {
122        Self { pool }
123    }
124
125    /// Run all pending migrations with mesh-safe locking.
126    ///
127    /// This acquires an exclusive advisory lock before running migrations,
128    /// ensuring only one node in the cluster runs migrations at a time.
129    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
130        // Acquire exclusive lock (blocks until acquired)
131        self.acquire_lock().await?;
132
133        let result = self.run_migrations_inner(user_migrations).await;
134
135        // Always release lock, even on error
136        if let Err(e) = self.release_lock().await {
137            warn!("Failed to release migration lock: {}", e);
138        }
139
140        result
141    }
142
143    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
144        // Ensure migration tracking table exists
145        self.ensure_migrations_table().await?;
146
147        // Get already-applied migrations
148        let applied = self.get_applied_migrations().await?;
149        debug!("Already applied migrations: {:?}", applied);
150
151        // Calculate the highest system version already applied
152        let max_applied_version = self.get_max_system_version(&applied);
153        debug!("Max applied system version: {:?}", max_applied_version);
154
155        // Run built-in FORGE system migrations first (in version order)
156        let system_migrations = super::builtin::get_system_migrations();
157        for sys_migration in system_migrations {
158            // Skip if this version (or equivalent legacy) is already applied
159            if let Some(max_ver) = max_applied_version {
160                if sys_migration.version <= max_ver {
161                    debug!(
162                        "Skipping system migration v{} (already at v{})",
163                        sys_migration.version, max_ver
164                    );
165                    continue;
166                }
167            }
168
169            let migration = sys_migration.to_migration();
170            info!(
171                "Applying system migration: {} ({})",
172                migration.name, sys_migration.description
173            );
174            self.apply_migration(&migration).await?;
175        }
176
177        // Then run user migrations (sorted by name)
178        for migration in user_migrations {
179            if !applied.contains(&migration.name) {
180                self.apply_migration(&migration).await?;
181            }
182        }
183
184        Ok(())
185    }
186
187    /// Get the maximum system migration version that has been applied.
188    /// Considers both new-style `__forge_vXXX` and legacy `0000_forge_internal`.
189    fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
190        applied
191            .iter()
192            .filter_map(|name| extract_version(name))
193            .max()
194    }
195
196    async fn acquire_lock(&self) -> Result<()> {
197        debug!("Acquiring migration lock...");
198        sqlx::query("SELECT pg_advisory_lock($1)")
199            .bind(MIGRATION_LOCK_ID)
200            .execute(&self.pool)
201            .await
202            .map_err(|e| {
203                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
204            })?;
205        debug!("Migration lock acquired");
206        Ok(())
207    }
208
209    async fn release_lock(&self) -> Result<()> {
210        sqlx::query("SELECT pg_advisory_unlock($1)")
211            .bind(MIGRATION_LOCK_ID)
212            .execute(&self.pool)
213            .await
214            .map_err(|e| {
215                ForgeError::Database(format!("Failed to release migration lock: {}", e))
216            })?;
217        debug!("Migration lock released");
218        Ok(())
219    }
220
221    async fn ensure_migrations_table(&self) -> Result<()> {
222        // Create table if not exists
223        sqlx::query(
224            r#"
225            CREATE TABLE IF NOT EXISTS forge_migrations (
226                id SERIAL PRIMARY KEY,
227                name VARCHAR(255) UNIQUE NOT NULL,
228                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
229                down_sql TEXT
230            )
231            "#,
232        )
233        .execute(&self.pool)
234        .await
235        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;
236
237        // Add down_sql column if it doesn't exist (for existing installations)
238        sqlx::query(
239            r#"
240            ALTER TABLE forge_migrations
241            ADD COLUMN IF NOT EXISTS down_sql TEXT
242            "#,
243        )
244        .execute(&self.pool)
245        .await
246        .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;
247
248        Ok(())
249    }
250
251    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
252        let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM forge_migrations")
253            .fetch_all(&self.pool)
254            .await
255            .map_err(|e| {
256                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
257            })?;
258
259        Ok(rows.into_iter().map(|(name,)| name).collect())
260    }
261
262    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
263        info!("Applying migration: {}", migration.name);
264
265        // Split migration into individual statements, respecting dollar-quoted strings
266        let statements = split_sql_statements(&migration.up_sql);
267
268        for statement in statements {
269            let statement = statement.trim();
270
271            // Skip empty statements or comment-only blocks
272            if statement.is_empty()
273                || statement.lines().all(|l| {
274                    let l = l.trim();
275                    l.is_empty() || l.starts_with("--")
276                })
277            {
278                continue;
279            }
280
281            sqlx::query(statement)
282                .execute(&self.pool)
283                .await
284                .map_err(|e| {
285                    ForgeError::Database(format!(
286                        "Failed to apply migration '{}': {}",
287                        migration.name, e
288                    ))
289                })?;
290        }
291
292        // Record it as applied (with down_sql for potential rollback)
293        sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
294            .bind(&migration.name)
295            .bind(&migration.down_sql)
296            .execute(&self.pool)
297            .await
298            .map_err(|e| {
299                ForgeError::Database(format!(
300                    "Failed to record migration '{}': {}",
301                    migration.name, e
302                ))
303            })?;
304
305        info!("Migration applied: {}", migration.name);
306        Ok(())
307    }
308
309    /// Rollback N migrations (most recent first).
310    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
311        if count == 0 {
312            return Ok(Vec::new());
313        }
314
315        // Acquire exclusive lock
316        self.acquire_lock().await?;
317
318        let result = self.rollback_inner(count).await;
319
320        // Always release lock
321        if let Err(e) = self.release_lock().await {
322            warn!("Failed to release migration lock: {}", e);
323        }
324
325        result
326    }
327
328    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
329        self.ensure_migrations_table().await?;
330
331        // Get the N most recent migrations with their down_sql
332        let rows: Vec<(i32, String, Option<String>)> = sqlx::query_as(
333            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
334        )
335        .bind(count as i32)
336        .fetch_all(&self.pool)
337        .await
338        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
339
340        if rows.is_empty() {
341            info!("No migrations to rollback");
342            return Ok(Vec::new());
343        }
344
345        let mut rolled_back = Vec::new();
346
347        for (id, name, down_sql) in rows {
348            info!("Rolling back migration: {}", name);
349
350            if let Some(down) = down_sql {
351                // Execute down SQL
352                let statements = split_sql_statements(&down);
353                for statement in statements {
354                    let statement = statement.trim();
355                    if statement.is_empty()
356                        || statement.lines().all(|l| {
357                            let l = l.trim();
358                            l.is_empty() || l.starts_with("--")
359                        })
360                    {
361                        continue;
362                    }
363
364                    sqlx::query(statement)
365                        .execute(&self.pool)
366                        .await
367                        .map_err(|e| {
368                            ForgeError::Database(format!(
369                                "Failed to rollback migration '{}': {}",
370                                name, e
371                            ))
372                        })?;
373                }
374            } else {
375                warn!("Migration '{}' has no down SQL, removing record only", name);
376            }
377
378            // Remove from migrations table
379            sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
380                .bind(id)
381                .execute(&self.pool)
382                .await
383                .map_err(|e| {
384                    ForgeError::Database(format!(
385                        "Failed to remove migration record '{}': {}",
386                        name, e
387                    ))
388                })?;
389
390            info!("Rolled back migration: {}", name);
391            rolled_back.push(name);
392        }
393
394        Ok(rolled_back)
395    }
396
397    /// Get the status of all migrations.
398    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
399        self.ensure_migrations_table().await?;
400
401        let applied = self.get_applied_migrations().await?;
402
403        let applied_list: Vec<AppliedMigration> = {
404            let rows: Vec<(String, chrono::DateTime<chrono::Utc>, Option<String>)> =
405                sqlx::query_as(
406                    "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC",
407                )
408                .fetch_all(&self.pool)
409                .await
410                .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
411
412            rows.into_iter()
413                .map(|(name, applied_at, down_sql)| AppliedMigration {
414                    name,
415                    applied_at,
416                    has_down: down_sql.is_some(),
417                })
418                .collect()
419        };
420
421        let pending: Vec<String> = available
422            .iter()
423            .filter(|m| !applied.contains(&m.name))
424            .map(|m| m.name.clone())
425            .collect();
426
427        Ok(MigrationStatus {
428            applied: applied_list,
429            pending,
430        })
431    }
432}
433
434/// Information about an applied migration.
435#[derive(Debug, Clone)]
436pub struct AppliedMigration {
437    pub name: String,
438    pub applied_at: chrono::DateTime<chrono::Utc>,
439    pub has_down: bool,
440}
441
442/// Status of migrations.
443#[derive(Debug, Clone)]
444pub struct MigrationStatus {
445    pub applied: Vec<AppliedMigration>,
446    pub pending: Vec<String>,
447}
448
/// Split SQL into individual statements on top-level semicolons.
///
/// A `;` ends a statement only when it lies outside every construct in which
/// PostgreSQL treats it as data or commentary:
/// - dollar-quoted strings (`$$ ... $$`, `$tag$ ... $tag$`), e.g. PL/pgSQL bodies;
/// - single-quoted literals (`'a;b'`): `''` is an escaped quote, and inside
///   `E'...'` strings a backslash also escapes the next character;
/// - double-quoted identifiers (`"a;b"`);
/// - `--` line comments and `/* ... */` block comments, which may nest.
///
/// Each statement is returned trimmed and without its terminating `;`.
/// Empty statements are dropped, and comments stay in the statement text. An
/// unterminated quote or comment runs to the end of the input.
fn split_sql_statements(sql: &str) -> Vec<String> {
    // Scanning bytes is safe: every delimiter is ASCII, and UTF-8 multi-byte
    // sequences only contain bytes >= 0x80. Offsets at delimiters are
    // therefore always char boundaries.
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    let mut start = 0; // byte offset where the current statement begins
    let mut i = 0;
    // Length of the unquoted word (identifier, keyword or number) ending at
    // `i`. Following PostgreSQL's lexer, a `$` inside a word is part of an
    // identifier (`a$b`). A one-letter word `E`/`e` directly before `'` opens
    // an escape string.
    let mut word_len = 0usize;

    while i < bytes.len() {
        let byte = bytes[i];
        let next = bytes.get(i + 1).copied();
        let (next_i, next_word_len) = match byte {
            b';' => {
                push_sql_statement(&mut statements, &sql[start..i]);
                start = i + 1;
                (i + 1, 0)
            }
            b'-' if next == Some(b'-') => (skip_line_comment(bytes, i), 0),
            b'/' if next == Some(b'*') => (skip_block_comment(bytes, i), 0),
            b'\'' => {
                let escape_string = word_len == 1 && matches!(bytes[i - 1], b'E' | b'e');
                (skip_quoted(bytes, i, b'\'', escape_string), 0)
            }
            b'"' => (skip_quoted(bytes, i, b'"', false), 0),
            b'$' if word_len > 0 => (i + 1, word_len + 1),
            b'$' => match dollar_quote_delimiter_len(bytes, i) {
                Some(delimiter_len) => (skip_dollar_quoted(sql, i, delimiter_len), 0),
                None => (i + 1, 0),
            },
            _ if is_sql_ident_byte(byte) => (i + 1, word_len + 1),
            _ => (i + 1, 0),
        };
        i = next_i;
        word_len = next_word_len;
    }

    // The last statement need not end with `;`.
    push_sql_statement(&mut statements, &sql[start..]);
    statements
}

/// Whether `byte` can appear in an unquoted PostgreSQL identifier (apart from
/// `$`). As in PostgreSQL's lexer, every byte >= 0x80 counts as a letter.
fn is_sql_ident_byte(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_' || byte >= 0x80
}

/// Append `raw`, trimmed and stripped of trailing `;`, unless that leaves it empty.
fn push_sql_statement(statements: &mut Vec<String>, raw: &str) {
    let statement = raw.trim().trim_end_matches(';').trim();
    if !statement.is_empty() {
        statements.push(statement.to_string());
    }
}

/// Given `--` at `pos`, return the index just past the end of that line.
fn skip_line_comment(bytes: &[u8], pos: usize) -> usize {
    bytes[pos..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |offset| pos + offset + 1)
}

/// Given `/*` at `pos`, return the index just past its matching `*/`,
/// honouring nested block comments.
fn skip_block_comment(bytes: &[u8], pos: usize) -> usize {
    let mut depth = 0usize;
    let mut i = pos;
    while i < bytes.len() {
        match (bytes[i], bytes.get(i + 1).copied()) {
            (b'/', Some(b'*')) => {
                depth += 1;
                i += 2;
            }
            (b'*', Some(b'/')) => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Given the opening `quote` byte at `pos`, return the index just past the
/// closing quote. A doubled quote is an escaped quote. When
/// `backslash_escapes` is set (`E'...'` strings), a backslash also escapes
/// the byte after it.
fn skip_quoted(bytes: &[u8], pos: usize, quote: u8, backslash_escapes: bool) -> usize {
    let mut i = pos + 1;
    while i < bytes.len() {
        let byte = bytes[i];
        if backslash_escapes && byte == b'\\' {
            i += 2;
        } else if byte == quote {
            if bytes.get(i + 1) == Some(&quote) {
                i += 2;
            } else {
                return i + 1;
            }
        } else {
            i += 1;
        }
    }
    bytes.len()
}

/// If a dollar-quote delimiter (`$$` or `$tag$`) starts at `pos`, return its
/// length in bytes. Tags follow identifier rules and cannot start with a
/// digit, so a positional parameter such as `$1` is not a delimiter.
fn dollar_quote_delimiter_len(bytes: &[u8], pos: usize) -> Option<usize> {
    let rest = bytes.get(pos + 1..)?;
    if rest.first().map_or(false, u8::is_ascii_digit) {
        return None;
    }
    let tag_len = rest.iter().take_while(|&&b| is_sql_ident_byte(b)).count();
    if rest.get(tag_len) == Some(&b'$') {
        Some(tag_len + 2)
    } else {
        None
    }
}

/// Given a dollar-quote delimiter `delimiter_len` bytes long at `pos`, return
/// the index just past the identical closing delimiter.
fn skip_dollar_quoted(sql: &str, pos: usize, delimiter_len: usize) -> usize {
    let body_start = pos + delimiter_len;
    let delimiter = &sql[pos..body_start];
    sql[body_start..]
        .find(delimiter)
        .map_or(sql.len(), |offset| body_start + offset + delimiter_len)
}

513/// Load user migrations from a directory.
514///
515/// Migrations should be named like:
516/// - `0001_create_users.sql`
517/// - `0002_add_posts.sql`
518///
519/// They are sorted alphabetically and executed in order.
520pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
521    if !dir.exists() {
522        debug!("Migrations directory does not exist: {:?}", dir);
523        return Ok(Vec::new());
524    }
525
526    let mut migrations = Vec::new();
527
528    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
529
530    for entry in entries {
531        let entry = entry.map_err(ForgeError::Io)?;
532        let path = entry.path();
533
534        if path.extension().map(|e| e == "sql").unwrap_or(false) {
535            let name = path
536                .file_stem()
537                .and_then(|s| s.to_str())
538                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
539                .to_string();
540
541            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
542
543            migrations.push(Migration::parse(name, &content));
544        }
545    }
546
547    // Sort by name (which includes the numeric prefix)
548    migrations.sort_by(|a, b| a.name.cmp(&b.name));
549
550    debug!("Loaded {} user migrations", migrations.len());
551    Ok(migrations)
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Create a fresh temporary directory holding the given `(file name, contents)` pairs.
    fn dir_with(files: &[(&str, &str)]) -> TempDir {
        let dir = TempDir::new().expect("create temp dir");
        for &(file_name, contents) in files {
            fs::write(dir.path().join(file_name), contents).expect("write fixture file");
        }
        dir
    }

    /// Names of `migrations`, in order.
    fn names(migrations: &[Migration]) -> Vec<&str> {
        migrations.iter().map(|m| m.name.as_str()).collect()
    }

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = dir_with(&[]);
        let loaded = load_migrations_from_dir(dir.path()).expect("empty dir loads");
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let loaded = load_migrations_from_dir(Path::new("/nonexistent/path"))
            .expect("missing dir is not an error");
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        // Written out of order on purpose.
        let dir = dir_with(&[
            ("0002_second.sql", "SELECT 2;"),
            ("0001_first.sql", "SELECT 1;"),
            ("0003_third.sql", "SELECT 3;"),
        ]);
        let loaded = load_migrations_from_dir(dir.path()).expect("migrations load");
        assert_eq!(names(&loaded), ["0001_first", "0002_second", "0003_third"]);
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = dir_with(&[
            ("0001_migration.sql", "SELECT 1;"),
            ("readme.txt", "Not a migration"),
            ("backup.sql.bak", "Backup"),
        ]);
        let loaded = load_migrations_from_dir(dir.path()).expect("migrations load");
        assert_eq!(names(&loaded), ["0001_migration"]);
    }

    #[test]
    fn test_migration_new() {
        let Migration {
            name,
            up_sql,
            down_sql,
        } = Migration::new("test", "SELECT 1");
        assert_eq!(name, "test");
        assert_eq!(up_sql, "SELECT 1");
        assert_eq!(down_sql, None);
    }

    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(
            (m.name.as_str(), m.up_sql.as_str(), m.down_sql.as_deref()),
            ("test", "CREATE TABLE t()", Some("DROP TABLE t"))
        );
    }

    #[test]
    fn test_migration_parse_up_only() {
        let m = Migration::parse("0001_test", "CREATE TABLE users (id INT);");
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        const CONTENT: &str = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", CONTENT);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql.as_deref(), Some("DROP TABLE users;"));
    }

    #[test]
    fn test_migration_parse_complex() {
        const CONTENT: &str = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", CONTENT);
        for needle in ["CREATE TABLE posts", "CREATE INDEX"] {
            assert!(m.up_sql.contains(needle), "up SQL should contain {:?}", needle);
        }
        let down = m.down_sql.expect("down section present");
        for needle in ["DROP INDEX", "DROP TABLE posts"] {
            assert!(down.contains(needle), "down SQL should contain {:?}", needle);
        }
    }

    #[test]
    fn test_split_simple_statements() {
        let stmts = super::split_sql_statements("SELECT 1; SELECT 2; SELECT 3;");
        assert_eq!(stmts, ["SELECT 1", "SELECT 2", "SELECT 3"]);
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        const SQL: &str = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(SQL);
        assert_eq!(stmts.len(), 2);
        let (function, tail) = (&stmts[0], &stmts[1]);
        assert!(function.contains("CREATE FUNCTION"));
        assert!(function.contains("$$ LANGUAGE plpgsql"));
        assert!(tail.contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        const SQL: &str = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        match super::split_sql_statements(SQL).as_slice() {
            [only] => assert!(only.contains("row_id := NEW.id::TEXT")),
            other => panic!("expected exactly one statement, got {}", other.len()),
        }
    }
}