// forge_runtime/migrations/runner.rs
1//! Migration runner with mesh-safe locking.
2//!
3//! Ensures only one node runs migrations at a time using PostgreSQL advisory locks.
4//!
5//! # Migration Types
6//!
7//! This runner handles two types of migrations:
8//!
9//! 1. **System migrations** (`__forge_vXXX`): Internal FORGE schema changes.
10//!    These are versioned numerically and always run before user migrations.
11//!
12//! 2. **User migrations** (`XXXX_name.sql`): Application-specific schema changes.
13//!    These are sorted alphabetically by name.
14
15use forge_core::error::{ForgeError, Result};
16use sqlx::{PgPool, Postgres};
17use std::collections::HashSet;
18use std::path::Path;
19use tracing::{debug, info, warn};
20
21use super::builtin::extract_version;
22
/// Lock ID for migration advisory lock (arbitrary but consistent).
/// Using a fixed value derived from "FORGE" ascii values.
/// (0x46 0x4F 0x52 0x47 0x45 = 'F' 'O' 'R' 'G' 'E' — any stable i64 works,
/// it only has to be identical on every node that shares the database.)
const MIGRATION_LOCK_ID: i64 = 0x464F524745; // "FORGE" in hex
26
/// One schema migration: forward SQL plus an optional rollback script.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Unique name/identifier (e.g., "0001_forge_internal" or "0002_create_users").
    pub name: String,
    /// SQL executed when migrating forward.
    pub up_sql: String,
    /// SQL executed on rollback, when provided.
    pub down_sql: Option<String>,
}

impl Migration {
    /// Build a forward-only migration (no rollback script).
    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
        Migration {
            name: name.into(),
            up_sql: sql.into(),
            down_sql: None,
        }
    }

    /// Build a migration carrying both forward and rollback SQL.
    pub fn with_down(
        name: impl Into<String>,
        up_sql: impl Into<String>,
        down_sql: impl Into<String>,
    ) -> Self {
        Migration {
            name: name.into(),
            up_sql: up_sql.into(),
            down_sql: Some(down_sql.into()),
        }
    }

    /// Parse raw file content that may carry `-- @up` / `-- @down` markers.
    pub fn parse(name: impl Into<String>, content: &str) -> Self {
        let (up_sql, down_sql) = parse_migration_content(content);
        Migration {
            name: name.into(),
            up_sql,
            down_sql,
        }
    }
}

/// Remove every recognized `-- @up` marker variant and trim the result.
fn strip_up_markers(sql: &str) -> String {
    ["-- @up", "--@up", "-- @UP", "--@UP"]
        .iter()
        .fold(sql.to_owned(), |acc, marker| acc.replace(marker, ""))
        .trim()
        .to_owned()
}

/// Split migration content on the first recognized `-- @down` marker variant.
/// Returns `(up_sql, Option<down_sql>)`; the down part is `None` when no
/// marker is present or when the section after the marker is blank.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    // Marker variants checked in this fixed order (case/space tolerant).
    let markers = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];

    let hit = markers
        .iter()
        .find_map(|m| content.find(m).map(|idx| (idx, m.len())));

    let Some((idx, marker_len)) = hit else {
        // No @down marker found — the whole file is forward SQL.
        return (strip_up_markers(content), None);
    };

    let up_sql = strip_up_markers(&content[..idx]);
    let down = content[idx + marker_len..].trim();
    if down.is_empty() {
        (up_sql, None)
    } else {
        (up_sql, Some(down.to_owned()))
    }
}
113
/// Migration runner that handles both built-in and user migrations.
///
/// Holds a `PgPool`; cluster-wide mutual exclusion is provided by a
/// Postgres advisory lock keyed on `MIGRATION_LOCK_ID`.
pub struct MigrationRunner {
    // Used both for the dedicated lock connection and for executing SQL.
    pool: PgPool,
}
118
impl MigrationRunner {
    /// Create a runner that uses `pool` for locking and for executing SQL.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Run all pending migrations with mesh-safe locking.
    ///
    /// This acquires an exclusive advisory lock before running migrations,
    /// ensuring only one node in the cluster runs migrations at a time.
    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
        // Acquire exclusive lock (blocks until acquired) on a dedicated connection.
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.run_migrations_inner(user_migrations).await;

        // Always release lock, even on error
        // (session-level advisory locks are also dropped by Postgres when the
        // connection closes, so a failed unlock is not fatal — just noisy).
        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    /// Apply built-in system migrations first (in version order), then any
    /// user migrations not yet recorded. Caller must hold the advisory lock.
    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
        // Ensure migration tracking table exists
        self.ensure_migrations_table().await?;

        // Get already-applied migrations
        let applied = self.get_applied_migrations().await?;
        debug!("Already applied migrations: {:?}", applied);

        // Calculate the highest system version already applied
        let max_applied_version = self.get_max_system_version(&applied);
        debug!("Max applied system version: {:?}", max_applied_version);

        // Run built-in FORGE system migrations first (in version order)
        let system_migrations = super::builtin::get_system_migrations();
        for sys_migration in system_migrations {
            // Skip if this version is already applied
            if let Some(max_ver) = max_applied_version
                && sys_migration.version <= max_ver
            {
                debug!(
                    "Skipping system migration v{} (already at v{})",
                    sys_migration.version, max_ver
                );
                continue;
            }

            let migration = sys_migration.to_migration();
            info!(
                "Applying system migration: {} ({})",
                migration.name, sys_migration.description
            );
            self.apply_migration(&migration).await?;
        }

        // Then run user migrations (sorted by name)
        for migration in user_migrations {
            if !applied.contains(&migration.name) {
                self.apply_migration(&migration).await?;
            }
        }

        Ok(())
    }

    /// Get the maximum system migration version that has been applied.
    ///
    /// Names without an extractable version (i.e. user migrations) are
    /// filtered out by `extract_version`, so they never affect the result.
    fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
        applied
            .iter()
            .filter_map(|name| extract_version(name))
            .max()
    }

    /// Check out a dedicated connection and block until the exclusive
    /// advisory lock is acquired on it.
    async fn acquire_lock_connection(&self) -> Result<sqlx::pool::PoolConnection<Postgres>> {
        debug!("Acquiring migration lock...");
        let mut conn = self.pool.acquire().await.map_err(|e| {
            ForgeError::Database(format!("Failed to acquire lock connection: {}", e))
        })?;

        // pg_advisory_lock blocks server-side until the lock is granted.
        sqlx::query_scalar!("SELECT pg_advisory_lock($1)", MIGRATION_LOCK_ID)
            .fetch_one(&mut *conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
            })?;
        debug!("Migration lock acquired");
        Ok(conn)
    }

    /// Release the advisory lock taken by `acquire_lock_connection`.
    /// Must run on the same connection that acquired the lock — session
    /// advisory locks are per-connection.
    async fn release_lock_connection(
        &self,
        conn: &mut sqlx::pool::PoolConnection<Postgres>,
    ) -> Result<()> {
        sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", MIGRATION_LOCK_ID)
            .fetch_one(&mut **conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to release migration lock: {}", e))
            })?;
        debug!("Migration lock released");
        Ok(())
    }

    /// Create the `forge_migrations` tracking table if it does not exist.
    async fn ensure_migrations_table(&self) -> Result<()> {
        // Create table if not exists
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS forge_migrations (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) UNIQUE NOT NULL,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                down_sql TEXT
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;

        Ok(())
    }

    /// Fetch the set of migration names already recorded as applied.
    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
        let rows = sqlx::query!("SELECT name FROM forge_migrations")
            .fetch_all(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
            })?;

        Ok(rows.into_iter().map(|row| row.name).collect())
    }

    /// Execute a migration's up SQL statement-by-statement, then record it
    /// in `forge_migrations` (storing `down_sql` for later rollback).
    ///
    /// NOTE(review): statements run on the pool outside a transaction, so a
    /// failure mid-migration leaves earlier statements applied with no
    /// tracking record — confirm this partial-apply behavior is intended.
    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
        info!("Applying migration: {}", migration.name);

        // Split migration into individual statements, respecting dollar-quoted strings
        let statements = split_sql_statements(&migration.up_sql);

        for statement in statements {
            let statement = statement.trim();

            // Skip empty statements or comment-only blocks
            if statement.is_empty()
                || statement.lines().all(|l| {
                    let l = l.trim();
                    l.is_empty() || l.starts_with("--")
                })
            {
                continue;
            }

            sqlx::query(statement)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to apply migration '{}': {}",
                        migration.name, e
                    ))
                })?;
        }

        // Record it as applied (with down_sql for potential rollback)
        sqlx::query!(
            "INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)",
            &migration.name,
            migration.down_sql as _,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| {
            ForgeError::Database(format!(
                "Failed to record migration '{}': {}",
                migration.name, e
            ))
        })?;

        info!("Migration applied: {}", migration.name);
        Ok(())
    }

    /// Rollback N migrations (most recent first).
    ///
    /// Returns the names of the rolled-back migrations in the order they
    /// were undone. Takes the same cluster-wide advisory lock as `run`.
    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
        if count == 0 {
            return Ok(Vec::new());
        }

        // Acquire exclusive lock on a dedicated connection.
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.rollback_inner(count).await;

        // Always release lock
        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    /// Undo up to `count` migrations, newest first, using the stored down
    /// SQL. Caller must hold the migration lock.
    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
        self.ensure_migrations_table().await?;

        // Get the N most recent migrations with their down_sql
        // NOTE(review): `count as i32` truncates for counts > i32::MAX —
        // fine for CLI-sized values, but `try_into` would be stricter.
        let rows = sqlx::query!(
            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
            count as i32
        )
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

        if rows.is_empty() {
            info!("No migrations to rollback");
            return Ok(Vec::new());
        }

        let mut rolled_back = Vec::new();

        for row in rows {
            let id = row.id;
            let name = row.name;
            let down_sql = row.down_sql;
            info!("Rolling back migration: {}", name);

            if let Some(down) = down_sql {
                // Execute down SQL
                let statements = split_sql_statements(&down);
                for statement in statements {
                    let statement = statement.trim();
                    // Skip empty statements or comment-only blocks
                    if statement.is_empty()
                        || statement.lines().all(|l| {
                            let l = l.trim();
                            l.is_empty() || l.starts_with("--")
                        })
                    {
                        continue;
                    }

                    sqlx::query(statement)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| {
                            ForgeError::Database(format!(
                                "Failed to rollback migration '{}': {}",
                                name, e
                            ))
                        })?;
                }
            } else {
                warn!("Migration '{}' has no down SQL, removing record only", name);
            }

            // Remove from migrations table
            sqlx::query!("DELETE FROM forge_migrations WHERE id = $1", id)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to remove migration record '{}': {}",
                        name, e
                    ))
                })?;

            info!("Rolled back migration: {}", name);
            rolled_back.push(name);
        }

        Ok(rolled_back)
    }

    /// Get the status of all migrations.
    ///
    /// `available` is the full set of known migrations; any whose name is
    /// not recorded as applied appears in `MigrationStatus::pending`.
    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;

        let applied_list: Vec<AppliedMigration> = {
            let rows = sqlx::query!(
                "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC"
            )
            .fetch_all(&self.pool)
            .await
            .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

            rows.into_iter()
                .map(|row| AppliedMigration {
                    name: row.name,
                    applied_at: row.applied_at,
                    has_down: row.down_sql.is_some(),
                })
                .collect()
        };

        let pending: Vec<String> = available
            .iter()
            .filter(|m| !applied.contains(&m.name))
            .map(|m| m.name.clone())
            .collect();

        Ok(MigrationStatus {
            applied: applied_list,
            pending,
        })
    }
}
428
/// Information about an applied migration.
#[derive(Debug, Clone)]
pub struct AppliedMigration {
    /// Migration name as recorded in `forge_migrations.name`.
    pub name: String,
    /// Timestamp at which the migration was recorded as applied.
    pub applied_at: chrono::DateTime<chrono::Utc>,
    /// Whether a rollback (`down_sql`) script was stored for this migration.
    pub has_down: bool,
}
436
/// Status of migrations.
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// Migrations already recorded in the tracking table, oldest first.
    pub applied: Vec<AppliedMigration>,
    /// Names of available migrations not yet applied.
    pub pending: Vec<String>,
}
443
/// Split SQL into individual statements, respecting dollar-quoted strings,
/// single-quoted string literals, and `--` line comments.
///
/// This handles PL/pgSQL functions that contain semicolons inside $$
/// delimiters, and fixes statement splitting on semicolons that appear
/// inside 'string literals' or `-- line comments` — those are text, not
/// statement terminators.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut in_single_quote = false;
    let mut in_line_comment = false;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        // A -- comment runs to end of line; everything inside is text.
        if in_line_comment {
            if c == '\n' {
                in_line_comment = false;
            }
            continue;
        }

        // Inside a 'string literal'; '' is the SQL escape for a quote.
        // NOTE(review): E'...' backslash escapes are not handled here —
        // confirm migrations avoid them before relying on such strings.
        if in_single_quote {
            if c == '\'' {
                if chars.peek() == Some(&'\'') {
                    // Escaped quote — consume the second ' and stay inside.
                    current.push(chars.next().expect("peeked char"));
                } else {
                    in_single_quote = false;
                }
            }
            continue;
        }

        // Inside a dollar-quoted body: only the matching closing tag ends it.
        if in_dollar_quote {
            if c == '$' && read_dollar_tag(&mut chars, &mut current) == dollar_tag {
                in_dollar_quote = false;
                dollar_tag.clear();
            }
            continue;
        }

        match c {
            '\'' => in_single_quote = true,
            '-' if chars.peek() == Some(&'-') => {
                // Safe: peek confirmed the char exists
                current.push(chars.next().expect("peeked char"));
                in_line_comment = true;
            }
            '$' => {
                // A valid opening delimiter looks like $$ or $tag$;
                // a bare positional param like $1 is ignored.
                let tag = read_dollar_tag(&mut chars, &mut current);
                if tag.len() >= 2 && tag.ends_with('$') {
                    in_dollar_quote = true;
                    dollar_tag = tag;
                }
            }
            ';' => {
                // Statement boundary outside all quoting constructs.
                let stmt = current.trim().trim_end_matches(';').trim().to_string();
                if !stmt.is_empty() {
                    statements.push(stmt);
                }
                current.clear();
            }
            _ => {}
        }
    }

    // Don't forget the last statement (might not end with ;)
    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}

/// Read the remainder of a dollar-quote tag after a leading `$`.
///
/// Consumes identifier characters (and a terminating `$`, if present) from
/// `chars`, mirroring each into `current`, and returns the candidate tag
/// including the leading `$`. A result ending in `$` (e.g. `$$`, `$body$`)
/// is a complete delimiter; anything else (e.g. `$1`) is not.
fn read_dollar_tag(
    chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
    current: &mut String,
) -> String {
    let mut tag = String::from("$");
    while let Some(&next_c) = chars.peek() {
        if next_c == '$' {
            // Safe: peek confirmed the char exists
            chars.next().expect("peeked char");
            tag.push('$');
            current.push('$');
            break;
        } else if next_c.is_alphanumeric() || next_c == '_' {
            chars.next().expect("peeked char");
            tag.push(next_c);
            current.push(next_c);
        } else {
            break;
        }
    }
    tag
}
509
510/// Load user migrations from a directory.
511///
512/// Migrations should be named like:
513/// - `0001_create_users.sql`
514/// - `0002_add_posts.sql`
515///
516/// They are sorted alphabetically and executed in order.
517pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
518    if !dir.exists() {
519        debug!("Migrations directory does not exist: {:?}", dir);
520        return Ok(Vec::new());
521    }
522
523    let mut migrations = Vec::new();
524
525    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
526
527    for entry in entries {
528        let entry = entry.map_err(ForgeError::Io)?;
529        let path = entry.path();
530
531        if path.extension().map(|e| e == "sql").unwrap_or(false) {
532            let name = path
533                .file_stem()
534                .and_then(|s| s.to_str())
535                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
536                .to_string();
537
538            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
539
540            migrations.push(Migration::parse(name, &content));
541        }
542    }
543
544    // Sort by name (which includes the numeric prefix)
545    migrations.sort_by(|a, b| a.name.cmp(&b.name));
546
547    debug!("Loaded {} user migrations", migrations.len());
548    Ok(migrations)
549}
550
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::fs;
    use tempfile::TempDir;

    // --- load_migrations_from_dir ---

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        // A missing directory is not an error — it just yields no migrations.
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        // Create migrations out of order
        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        // Only the *.sql extension counts; .sql.bak has extension "bak".
        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    // --- Migration constructors and marker parsing ---

    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    #[test]
    fn test_migration_parse_up_only() {
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    // --- MigrationRunner helpers ---

    #[tokio::test]
    async fn test_get_max_system_version_prefers_highest_applied_version() {
        // connect_lazy never opens a connection, so no database is needed.
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("lazy pool must build");
        let runner = MigrationRunner::new(pool);

        // User migration names carry no system version and are ignored.
        let applied = HashSet::from([
            "__forge_v003".to_string(),
            "__forge_v001".to_string(),
            "0001_user_schema".to_string(),
        ]);

        assert_eq!(runner.get_max_system_version(&applied), Some(3));
    }

    // --- split_sql_statements ---

    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}