forge_runtime/migrations/
runner.rs

1//! Migration runner with mesh-safe locking.
2//!
3//! Ensures only one node runs migrations at a time using PostgreSQL advisory locks.
4
5use forge_core::error::{ForgeError, Result};
6use sqlx::PgPool;
7use std::collections::HashSet;
8use std::path::Path;
9use tracing::{debug, info, warn};
10
/// Key of the PostgreSQL advisory lock that serializes migrations.
///
/// The value is the ASCII bytes of `"FORGE"` packed big-endian into an `i64`
/// (`0x464F524745`). The number itself is arbitrary; it only has to be the
/// same on every node.
const MIGRATION_LOCK_ID: i64 = i64::from_be_bytes([0, 0, 0, b'F', b'O', b'R', b'G', b'E']);
14
15/// A single migration with up and optional down SQL.
16#[derive(Debug, Clone)]
17pub struct Migration {
18    /// Unique name/identifier (e.g., "0001_forge_internal" or "0002_create_users").
19    pub name: String,
20    /// SQL to execute for upgrade (forward migration).
21    pub up_sql: String,
22    /// SQL to execute for rollback (optional).
23    pub down_sql: Option<String>,
24}
25
26impl Migration {
27    /// Create a migration with only up SQL (no rollback).
28    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
29        Self {
30            name: name.into(),
31            up_sql: sql.into(),
32            down_sql: None,
33        }
34    }
35
36    /// Create a migration with both up and down SQL.
37    pub fn with_down(
38        name: impl Into<String>,
39        up_sql: impl Into<String>,
40        down_sql: impl Into<String>,
41    ) -> Self {
42        Self {
43            name: name.into(),
44            up_sql: up_sql.into(),
45            down_sql: Some(down_sql.into()),
46        }
47    }
48
49    /// Parse migration content that may contain -- @up and -- @down markers.
50    pub fn parse(name: impl Into<String>, content: &str) -> Self {
51        let name = name.into();
52        let (up_sql, down_sql) = parse_migration_content(content);
53        Self {
54            name,
55            up_sql,
56            down_sql,
57        }
58    }
59}
60
/// Section that a marker comment line in a migration file introduces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SectionMarker {
    Up,
    Down,
}

/// Parse migration content, splitting it at the `-- @down` marker.
///
/// Returns `(up_sql, down_sql)`.
///
/// A marker is a whole line consisting of a `--` comment whose text starts with
/// `@up` or `@down`. Matching is ASCII case-insensitive, whitespace between `--`
/// and `@` is optional, and the tag must be a whole word: `-- @updated_at` and
/// `-- @downtime` are ordinary comments. Everything after the marker on its
/// line is treated as part of the comment.
///
/// - `-- @up` lines are dropped from the up section; they only label it.
/// - The first `-- @down` line ends the up section. Everything after it,
///   trimmed, is the down SQL, or `None` if that is empty.
/// - Without a `-- @down` marker, the whole (trimmed) content is the up SQL.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    let mut up_sql = String::with_capacity(content.len());
    // Byte offset just past the line currently being examined.
    let mut offset = 0;

    for line in content.split_inclusive('\n') {
        offset += line.len();
        match section_marker(line) {
            Some(SectionMarker::Down) => {
                let down_sql = content[offset..].trim();
                let down_sql = (!down_sql.is_empty()).then(|| down_sql.to_string());
                return (up_sql.trim().to_string(), down_sql);
            }
            // The marker line only labels the section; keep everything else.
            Some(SectionMarker::Up) => {}
            None => up_sql.push_str(line),
        }
    }

    (up_sql.trim().to_string(), None)
}

/// Classifies `line` as an `@up` / `@down` marker comment, if it is one.
fn section_marker(line: &str) -> Option<SectionMarker> {
    let tag = line
        .trim()
        .strip_prefix("--")?
        .trim_start()
        .strip_prefix('@')?;
    if starts_with_word(tag, "down") {
        Some(SectionMarker::Down)
    } else if starts_with_word(tag, "up") {
        Some(SectionMarker::Up)
    } else {
        None
    }
}

/// True when `text` starts with `word` (ASCII case-insensitive) and the next
/// character does not continue the word, so `update` does not match `up`.
fn starts_with_word(text: &str, word: &str) -> bool {
    match text.get(..word.len()) {
        Some(head) if head.eq_ignore_ascii_case(word) => !text[word.len()..]
            .starts_with(|c: char| c.is_alphanumeric() || c == '_'),
        _ => false,
    }
}
101
102/// Migration runner that handles both built-in and user migrations.
103pub struct MigrationRunner {
104    pool: PgPool,
105}
106
107impl MigrationRunner {
108    pub fn new(pool: PgPool) -> Self {
109        Self { pool }
110    }
111
112    /// Run all pending migrations with mesh-safe locking.
113    ///
114    /// This acquires an exclusive advisory lock before running migrations,
115    /// ensuring only one node in the cluster runs migrations at a time.
116    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
117        // Acquire exclusive lock (blocks until acquired)
118        self.acquire_lock().await?;
119
120        let result = self.run_migrations_inner(user_migrations).await;
121
122        // Always release lock, even on error
123        if let Err(e) = self.release_lock().await {
124            warn!("Failed to release migration lock: {}", e);
125        }
126
127        result
128    }
129
130    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
131        // Ensure migration tracking table exists
132        self.ensure_migrations_table().await?;
133
134        // Get already-applied migrations
135        let applied = self.get_applied_migrations().await?;
136        debug!("Already applied migrations: {:?}", applied);
137
138        // Run built-in FORGE migrations first
139        let builtin = super::builtin::get_builtin_migrations();
140        for migration in builtin {
141            if !applied.contains(&migration.name) {
142                self.apply_migration(&migration).await?;
143            }
144        }
145
146        // Then run user migrations
147        for migration in user_migrations {
148            if !applied.contains(&migration.name) {
149                self.apply_migration(&migration).await?;
150            }
151        }
152
153        Ok(())
154    }
155
156    async fn acquire_lock(&self) -> Result<()> {
157        debug!("Acquiring migration lock...");
158        sqlx::query("SELECT pg_advisory_lock($1)")
159            .bind(MIGRATION_LOCK_ID)
160            .execute(&self.pool)
161            .await
162            .map_err(|e| {
163                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
164            })?;
165        debug!("Migration lock acquired");
166        Ok(())
167    }
168
169    async fn release_lock(&self) -> Result<()> {
170        sqlx::query("SELECT pg_advisory_unlock($1)")
171            .bind(MIGRATION_LOCK_ID)
172            .execute(&self.pool)
173            .await
174            .map_err(|e| {
175                ForgeError::Database(format!("Failed to release migration lock: {}", e))
176            })?;
177        debug!("Migration lock released");
178        Ok(())
179    }
180
181    async fn ensure_migrations_table(&self) -> Result<()> {
182        // Create table if not exists
183        sqlx::query(
184            r#"
185            CREATE TABLE IF NOT EXISTS forge_migrations (
186                id SERIAL PRIMARY KEY,
187                name VARCHAR(255) UNIQUE NOT NULL,
188                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
189                down_sql TEXT
190            )
191            "#,
192        )
193        .execute(&self.pool)
194        .await
195        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;
196
197        // Add down_sql column if it doesn't exist (for existing installations)
198        sqlx::query(
199            r#"
200            ALTER TABLE forge_migrations
201            ADD COLUMN IF NOT EXISTS down_sql TEXT
202            "#,
203        )
204        .execute(&self.pool)
205        .await
206        .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;
207
208        Ok(())
209    }
210
211    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
212        let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM forge_migrations")
213            .fetch_all(&self.pool)
214            .await
215            .map_err(|e| {
216                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
217            })?;
218
219        Ok(rows.into_iter().map(|(name,)| name).collect())
220    }
221
222    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
223        info!("Applying migration: {}", migration.name);
224
225        // Split migration into individual statements, respecting dollar-quoted strings
226        let statements = split_sql_statements(&migration.up_sql);
227
228        for statement in statements {
229            let statement = statement.trim();
230
231            // Skip empty statements or comment-only blocks
232            if statement.is_empty()
233                || statement.lines().all(|l| {
234                    let l = l.trim();
235                    l.is_empty() || l.starts_with("--")
236                })
237            {
238                continue;
239            }
240
241            sqlx::query(statement)
242                .execute(&self.pool)
243                .await
244                .map_err(|e| {
245                    ForgeError::Database(format!(
246                        "Failed to apply migration '{}': {}",
247                        migration.name, e
248                    ))
249                })?;
250        }
251
252        // Record it as applied (with down_sql for potential rollback)
253        sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
254            .bind(&migration.name)
255            .bind(&migration.down_sql)
256            .execute(&self.pool)
257            .await
258            .map_err(|e| {
259                ForgeError::Database(format!(
260                    "Failed to record migration '{}': {}",
261                    migration.name, e
262                ))
263            })?;
264
265        info!("Migration applied: {}", migration.name);
266        Ok(())
267    }
268
269    /// Rollback N migrations (most recent first).
270    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
271        if count == 0 {
272            return Ok(Vec::new());
273        }
274
275        // Acquire exclusive lock
276        self.acquire_lock().await?;
277
278        let result = self.rollback_inner(count).await;
279
280        // Always release lock
281        if let Err(e) = self.release_lock().await {
282            warn!("Failed to release migration lock: {}", e);
283        }
284
285        result
286    }
287
288    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
289        self.ensure_migrations_table().await?;
290
291        // Get the N most recent migrations with their down_sql
292        let rows: Vec<(i32, String, Option<String>)> = sqlx::query_as(
293            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
294        )
295        .bind(count as i32)
296        .fetch_all(&self.pool)
297        .await
298        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
299
300        if rows.is_empty() {
301            info!("No migrations to rollback");
302            return Ok(Vec::new());
303        }
304
305        let mut rolled_back = Vec::new();
306
307        for (id, name, down_sql) in rows {
308            info!("Rolling back migration: {}", name);
309
310            if let Some(down) = down_sql {
311                // Execute down SQL
312                let statements = split_sql_statements(&down);
313                for statement in statements {
314                    let statement = statement.trim();
315                    if statement.is_empty()
316                        || statement.lines().all(|l| {
317                            let l = l.trim();
318                            l.is_empty() || l.starts_with("--")
319                        })
320                    {
321                        continue;
322                    }
323
324                    sqlx::query(statement)
325                        .execute(&self.pool)
326                        .await
327                        .map_err(|e| {
328                            ForgeError::Database(format!(
329                                "Failed to rollback migration '{}': {}",
330                                name, e
331                            ))
332                        })?;
333                }
334            } else {
335                warn!("Migration '{}' has no down SQL, removing record only", name);
336            }
337
338            // Remove from migrations table
339            sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
340                .bind(id)
341                .execute(&self.pool)
342                .await
343                .map_err(|e| {
344                    ForgeError::Database(format!(
345                        "Failed to remove migration record '{}': {}",
346                        name, e
347                    ))
348                })?;
349
350            info!("Rolled back migration: {}", name);
351            rolled_back.push(name);
352        }
353
354        Ok(rolled_back)
355    }
356
357    /// Get the status of all migrations.
358    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
359        self.ensure_migrations_table().await?;
360
361        let applied = self.get_applied_migrations().await?;
362
363        let applied_list: Vec<AppliedMigration> = {
364            let rows: Vec<(String, chrono::DateTime<chrono::Utc>, Option<String>)> =
365                sqlx::query_as(
366                    "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC",
367                )
368                .fetch_all(&self.pool)
369                .await
370                .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
371
372            rows.into_iter()
373                .map(|(name, applied_at, down_sql)| AppliedMigration {
374                    name,
375                    applied_at,
376                    has_down: down_sql.is_some(),
377                })
378                .collect()
379        };
380
381        let pending: Vec<String> = available
382            .iter()
383            .filter(|m| !applied.contains(&m.name))
384            .map(|m| m.name.clone())
385            .collect();
386
387        Ok(MigrationStatus {
388            applied: applied_list,
389            pending,
390        })
391    }
392}
393
394/// Information about an applied migration.
395#[derive(Debug, Clone)]
396pub struct AppliedMigration {
397    pub name: String,
398    pub applied_at: chrono::DateTime<chrono::Utc>,
399    pub has_down: bool,
400}
401
402/// Status of migrations.
403#[derive(Debug, Clone)]
404pub struct MigrationStatus {
405    pub applied: Vec<AppliedMigration>,
406    pub pending: Vec<String>,
407}
408
/// Split SQL into individual statements on top-level semicolons.
///
/// A `;` only ends a statement when it is outside all of the following:
/// - single-quoted string literals (`'a;b'`). `''` is an escaped quote, and
///   backslash escapes are honored inside `E'...'` escape strings;
/// - double-quoted identifiers (`"we;ird"`), where `""` is an escaped quote;
/// - `--` line comments and (possibly nested) `/* ... */` block comments;
/// - dollar-quoted strings (`$$ ... $$` or `$tag$ ... $tag$`). This is how
///   PL/pgSQL function bodies that contain semicolons are written.
///
/// A `$` that continues an identifier (`a$b`) or starts a positional
/// parameter (`$1`) does not open a dollar quote.
///
/// Each returned statement is trimmed and has no trailing `;`. Empty pieces are
/// dropped; comment-only pieces are returned as-is. An unterminated quote or
/// comment runs to the end of the input.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let bytes = sql.as_bytes();
    let mut statements = Vec::new();
    // Byte offset where the statement currently being scanned begins.
    let mut start = 0;
    let mut i = 0;

    // Every delimiter is ASCII, and all bytes of multi-byte UTF-8 characters
    // are >= 0x80, so byte-level scanning never slices inside a character.
    while i < bytes.len() {
        i = match bytes[i] {
            b'\'' => skip_quoted(bytes, i, is_escape_string_start(bytes, i)),
            b'"' => skip_quoted(bytes, i, false),
            b'-' if bytes.get(i + 1) == Some(&b'-') => skip_line_comment(bytes, i),
            b'/' if bytes.get(i + 1) == Some(&b'*') => skip_block_comment(bytes, i),
            b'$' => skip_dollar_quoted(sql, i),
            b';' => {
                push_statement(&mut statements, &sql[start..i]);
                start = i + 1;
                i + 1
            }
            _ => i + 1,
        };
    }

    // The last statement need not end with `;`.
    push_statement(&mut statements, &sql[start..]);

    statements
}

/// Trim `piece` and append it to `statements` unless it is empty.
fn push_statement(statements: &mut Vec<String>, piece: &str) {
    let stmt = piece.trim().trim_end_matches(';').trim();
    if !stmt.is_empty() {
        statements.push(stmt.to_string());
    }
}

/// Byte that may appear in a dollar-quote tag. Bytes >= 0x80 cover non-ASCII
/// letters, which PostgreSQL allows in identifiers.
fn is_tag_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80
}

/// Byte that may continue an unquoted identifier (identifiers may contain `$`).
fn is_identifier_byte(b: u8) -> bool {
    is_tag_byte(b) || b == b'$'
}

/// Whether the `'` at `bytes[quote]` opens an `E'...'` escape string. That is
/// the case when it directly follows an `E`/`e` that is not itself the end of a
/// longer identifier.
fn is_escape_string_start(bytes: &[u8], quote: usize) -> bool {
    match quote.checked_sub(1).map(|k| bytes[k]) {
        Some(b'E' | b'e') => quote < 2 || !is_identifier_byte(bytes[quote - 2]),
        _ => false,
    }
}

/// Index just past the `'`/`"` section opened at `bytes[open]`, or
/// `bytes.len()` if it is never closed.
///
/// A doubled quote character is an escaped quote. With `backslash_escapes`, a
/// backslash also escapes the byte after it.
fn skip_quoted(bytes: &[u8], open: usize, backslash_escapes: bool) -> usize {
    let quote = bytes[open];
    let mut j = open + 1;
    while j < bytes.len() {
        match bytes[j] {
            b'\\' if backslash_escapes => j += 2,
            b if b == quote => {
                if bytes.get(j + 1) == Some(&quote) {
                    j += 2;
                } else {
                    return j + 1;
                }
            }
            _ => j += 1,
        }
    }
    bytes.len()
}

/// Index just past the `--` comment starting at `start` (including its newline).
fn skip_line_comment(bytes: &[u8], start: usize) -> usize {
    bytes[start..]
        .iter()
        .position(|&b| b == b'\n')
        .map_or(bytes.len(), |off| start + off + 1)
}

/// Index just past the `/* ... */` comment starting at `start`, honoring
/// PostgreSQL's nested block comments.
fn skip_block_comment(bytes: &[u8], start: usize) -> usize {
    let mut depth = 0usize;
    let mut j = start;
    while j < bytes.len() {
        if bytes[j] == b'/' && bytes.get(j + 1) == Some(&b'*') {
            depth += 1;
            j += 2;
        } else if bytes[j] == b'*' && bytes.get(j + 1) == Some(&b'/') {
            depth -= 1;
            j += 2;
            if depth == 0 {
                return j;
            }
        } else {
            j += 1;
        }
    }
    bytes.len()
}

/// Handle the `$` at `sql[dollar]`.
///
/// If it opens a dollar-quoted string (`$$` or `$tag$`), returns the index just
/// past the matching closing tag, or `sql.len()` if there is none. Otherwise
/// (a `$1` parameter, or a `$` inside an identifier) returns `dollar + 1`.
fn skip_dollar_quoted(sql: &str, dollar: usize) -> usize {
    let bytes = sql.as_bytes();
    let not_a_quote = dollar + 1;

    // A `$` directly after an identifier character is part of that identifier.
    if dollar > 0 && is_identifier_byte(bytes[dollar - 1]) {
        return not_a_quote;
    }

    // The tag is empty, or an identifier that does not start with a digit.
    let mut end = dollar + 1;
    if matches!(bytes.get(end), Some(&b) if is_tag_byte(b) && !b.is_ascii_digit()) {
        end += 1;
        while matches!(bytes.get(end), Some(&b) if is_tag_byte(b)) {
            end += 1;
        }
    }
    if bytes.get(end) != Some(&b'$') {
        return not_a_quote;
    }

    let tag = &sql[dollar..=end];
    let body = end + 1;
    sql[body..]
        .find(tag)
        .map_or(sql.len(), |off| body + off + tag.len())
}
472
473/// Load user migrations from a directory.
474///
475/// Migrations should be named like:
476/// - `0001_create_users.sql`
477/// - `0002_add_posts.sql`
478///
479/// They are sorted alphabetically and executed in order.
480pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
481    if !dir.exists() {
482        debug!("Migrations directory does not exist: {:?}", dir);
483        return Ok(Vec::new());
484    }
485
486    let mut migrations = Vec::new();
487
488    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
489
490    for entry in entries {
491        let entry = entry.map_err(ForgeError::Io)?;
492        let path = entry.path();
493
494        if path.extension().map(|e| e == "sql").unwrap_or(false) {
495            let name = path
496                .file_stem()
497                .and_then(|s| s.to_str())
498                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
499                .to_string();
500
501            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
502
503            migrations.push(Migration::parse(name, &content));
504        }
505    }
506
507    // Sort by name (which includes the numeric prefix)
508    migrations.sort_by(|a, b| a.name.cmp(&b.name));
509
510    debug!("Loaded {} user migrations", migrations.len());
511    Ok(migrations)
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a temporary directory holding `files` as `(file name, contents)` pairs.
    fn dir_with(files: &[(&str, &str)]) -> TempDir {
        let dir = TempDir::new().unwrap();
        for (file_name, contents) in files {
            fs::write(dir.path().join(file_name), contents).unwrap();
        }
        dir
    }

    /// Names of `migrations`, in order.
    fn names(migrations: &[Migration]) -> Vec<&str> {
        migrations.iter().map(|m| m.name.as_str()).collect()
    }

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = dir_with(&[]);
        assert!(load_migrations_from_dir(dir.path()).unwrap().is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let loaded = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        // Written out of order on purpose.
        let dir = dir_with(&[
            ("0002_second.sql", "SELECT 2;"),
            ("0001_first.sql", "SELECT 1;"),
            ("0003_third.sql", "SELECT 3;"),
        ]);

        let loaded = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(names(&loaded), ["0001_first", "0002_second", "0003_third"]);
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = dir_with(&[
            ("0001_migration.sql", "SELECT 1;"),
            ("readme.txt", "Not a migration"),
            ("backup.sql.bak", "Backup"),
        ]);

        let loaded = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(names(&loaded), ["0001_migration"]);
    }

    #[test]
    fn test_migration_new() {
        let migration = Migration::new("test", "SELECT 1");
        assert_eq!(migration.name, "test");
        assert_eq!(migration.up_sql, "SELECT 1");
        assert_eq!(migration.down_sql, None);
    }

    #[test]
    fn test_migration_with_down() {
        let migration = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(migration.name, "test");
        assert_eq!(migration.up_sql, "CREATE TABLE t()");
        assert_eq!(migration.down_sql.as_deref(), Some("DROP TABLE t"));
    }

    #[test]
    fn test_migration_parse_up_only() {
        let sql = "CREATE TABLE users (id INT);";
        let migration = Migration::parse("0001_test", sql);
        assert_eq!(migration.name, "0001_test");
        assert_eq!(migration.up_sql, sql);
        assert!(migration.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let migration = Migration::parse("0001_users", content);
        let up = migration.up_sql.as_str();

        assert_eq!(migration.name, "0001_users");
        assert!(up.contains("CREATE TABLE users"));
        assert!(!up.contains("@up"));
        assert!(!up.contains("DROP TABLE"));
        assert_eq!(migration.down_sql.as_deref(), Some("DROP TABLE users;"));
    }

    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let migration = Migration::parse("0002_posts", content);
        for expected in ["CREATE TABLE posts", "CREATE INDEX"] {
            assert!(migration.up_sql.contains(expected));
        }

        let down = migration.down_sql.expect("down section should be present");
        for expected in ["DROP INDEX", "DROP TABLE posts"] {
            assert!(down.contains(expected));
        }
    }

    #[test]
    fn test_split_simple_statements() {
        let stmts = split_sql_statements("SELECT 1; SELECT 2; SELECT 3;");
        assert_eq!(stmts, ["SELECT 1", "SELECT 2", "SELECT 3"]);
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);

        let (function, tail) = (&stmts[0], &stmts[1]);
        assert!(function.contains("CREATE FUNCTION"));
        assert!(function.contains("$$ LANGUAGE plpgsql"));
        assert!(tail.contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}