forge_runtime/migrations/runner.rs

use forge_core::error::{ForgeError, Result};
use sqlx::PgPool;
use std::collections::HashSet;
use std::path::Path;
use tracing::{debug, info, warn};

use super::builtin::extract_version;
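/// Advisory lock key used to serialize concurrent migration runs ("FORGE" in ASCII bytes).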
const MIGRATION_LOCK_ID: i64 = 0x464F524745;

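/// A single schema migration: a name, the SQL to apply, and optional rollback SQL.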
#[derive(Debug, Clone)]
pub struct Migration {
    pub name: String,
    pub up_sql: String,
    pub down_sql: Option<String>,
}

impl Migration {
    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            up_sql: sql.into(),
            down_sql: None,
        }
    }

    pub fn with_down(
        name: impl Into<String>,
        up_sql: impl Into<String>,
        down_sql: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            up_sql: up_sql.into(),
            down_sql: Some(down_sql.into()),
        }
    }

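    /// Parses raw migration file content, honoring optional `-- @up` / `-- @down` markers.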
    pub fn parse(name: impl Into<String>, content: &str) -> Self {
        let name = name.into();
        let (up_sql, down_sql) = parse_migration_content(content);
        Self {
            name,
            up_sql,
            down_sql,
        }
    }
}

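/// Splits file content into `(up_sql, Option<down_sql>)`, stripping `-- @up` markers and
/// treating everything after a `-- @down` marker as rollback SQL.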
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    let down_marker_patterns = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];

    for pattern in down_marker_patterns {
        if let Some(idx) = content.find(pattern) {
            let up_part = &content[..idx];
            let down_part = &content[idx + pattern.len()..];

            let up_sql = up_part
                .replace("-- @up", "")
                .replace("--@up", "")
                .replace("-- @UP", "")
                .replace("--@UP", "")
                .trim()
                .to_string();

            let down_sql = down_part.trim().to_string();

            if down_sql.is_empty() {
                return (up_sql, None);
            }
            return (up_sql, Some(down_sql));
        }
    }

    let up_sql = content
        .replace("-- @up", "")
        .replace("--@up", "")
        .replace("-- @UP", "")
        .replace("--@UP", "")
        .trim()
        .to_string();

    (up_sql, None)
}

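/// Applies and rolls back migrations against Postgres using a shared connection pool.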
pub struct MigrationRunner {
    pool: PgPool,
}

impl MigrationRunner {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

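    /// Applies all pending system and user migrations, serialized behind a Postgres advisory lock.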
    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
        self.acquire_lock().await?;

        let result = self.run_migrations_inner(user_migrations).await;

        if let Err(e) = self.release_lock().await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;
        debug!("Already applied migrations: {:?}", applied);

        let max_applied_version = self.get_max_system_version(&applied);
        debug!("Max applied system version: {:?}", max_applied_version);

        let system_migrations = super::builtin::get_system_migrations();
        for sys_migration in system_migrations {
            if let Some(max_ver) = max_applied_version {
                if sys_migration.version <= max_ver {
                    debug!(
                        "Skipping system migration v{} (already at v{})",
                        sys_migration.version, max_ver
                    );
                    continue;
                }
            }

            let migration = sys_migration.to_migration();
            info!(
                "Applying system migration: {} ({})",
                migration.name, sys_migration.description
            );
            self.apply_migration(&migration).await?;
        }

        for migration in user_migrations {
            if !applied.contains(&migration.name) {
                self.apply_migration(&migration).await?;
            }
        }

        Ok(())
    }

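    /// Returns the highest built-in (system) migration version found among the applied migration names.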
    fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
        applied
            .iter()
            .filter_map(|name| extract_version(name))
            .max()
    }

    async fn acquire_lock(&self) -> Result<()> {
        debug!("Acquiring migration lock...");
        sqlx::query("SELECT pg_advisory_lock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
            })?;
        debug!("Migration lock acquired");
        Ok(())
    }

    async fn release_lock(&self) -> Result<()> {
        sqlx::query("SELECT pg_advisory_unlock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to release migration lock: {}", e))
            })?;
        debug!("Migration lock released");
        Ok(())
    }

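    /// Creates the `forge_migrations` tracking table if needed and adds the `down_sql` column when missing.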
    async fn ensure_migrations_table(&self) -> Result<()> {
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS forge_migrations (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) UNIQUE NOT NULL,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                down_sql TEXT
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;

        sqlx::query(
            r#"
            ALTER TABLE forge_migrations
            ADD COLUMN IF NOT EXISTS down_sql TEXT
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;

        Ok(())
    }

    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
        let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM forge_migrations")
            .fetch_all(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
            })?;

        Ok(rows.into_iter().map(|(name,)| name).collect())
    }

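    /// Executes a migration's up SQL statement by statement, then records it in `forge_migrations`.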
    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
        info!("Applying migration: {}", migration.name);

        let statements = split_sql_statements(&migration.up_sql);

        for statement in statements {
            let statement = statement.trim();

            // Skip statements that are empty or consist only of SQL comments.
            if statement.is_empty()
                || statement.lines().all(|l| {
                    let l = l.trim();
                    l.is_empty() || l.starts_with("--")
                })
            {
                continue;
            }

            sqlx::query(statement)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to apply migration '{}': {}",
                        migration.name, e
                    ))
                })?;
        }

        sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
            .bind(&migration.name)
            .bind(&migration.down_sql)
            .execute(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!(
                    "Failed to record migration '{}': {}",
                    migration.name, e
                ))
            })?;

        info!("Migration applied: {}", migration.name);
        Ok(())
    }

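    /// Rolls back the most recent `count` migrations in reverse order, returning the names rolled back.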
    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
        if count == 0 {
            return Ok(Vec::new());
        }

        self.acquire_lock().await?;

        let result = self.rollback_inner(count).await;

        if let Err(e) = self.release_lock().await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
        self.ensure_migrations_table().await?;

        let rows: Vec<(i32, String, Option<String>)> = sqlx::query_as(
            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
        )
        .bind(count as i32)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

        if rows.is_empty() {
            info!("No migrations to rollback");
            return Ok(Vec::new());
        }

        let mut rolled_back = Vec::new();

        for (id, name, down_sql) in rows {
            info!("Rolling back migration: {}", name);

            if let Some(down) = down_sql {
                let statements = split_sql_statements(&down);
                for statement in statements {
                    let statement = statement.trim();
                    if statement.is_empty()
                        || statement.lines().all(|l| {
                            let l = l.trim();
                            l.is_empty() || l.starts_with("--")
                        })
                    {
                        continue;
                    }

                    sqlx::query(statement)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| {
                            ForgeError::Database(format!(
                                "Failed to rollback migration '{}': {}",
                                name, e
                            ))
                        })?;
                }
            } else {
                warn!("Migration '{}' has no down SQL, removing record only", name);
            }

            sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
                .bind(id)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to remove migration record '{}': {}",
                        name, e
                    ))
                })?;

            info!("Rolled back migration: {}", name);
            rolled_back.push(name);
        }

        Ok(rolled_back)
    }

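    /// Reports which migrations have been applied and which of the given migrations are still pending.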
    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;

        let applied_list: Vec<AppliedMigration> = {
            let rows: Vec<(String, chrono::DateTime<chrono::Utc>, Option<String>)> =
                sqlx::query_as(
                    "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC",
                )
                .fetch_all(&self.pool)
                .await
                .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

            rows.into_iter()
                .map(|(name, applied_at, down_sql)| AppliedMigration {
                    name,
                    applied_at,
                    has_down: down_sql.is_some(),
                })
                .collect()
        };

        let pending: Vec<String> = available
            .iter()
            .filter(|m| !applied.contains(&m.name))
            .map(|m| m.name.clone())
            .collect();

        Ok(MigrationStatus {
            applied: applied_list,
            pending,
        })
    }
}

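/// A migration recorded as applied, including when it ran and whether down SQL is available.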
#[derive(Debug, Clone)]
pub struct AppliedMigration {
    pub name: String,
    pub applied_at: chrono::DateTime<chrono::Utc>,
    pub has_down: bool,
}

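/// A snapshot of applied and pending migrations, as returned by `MigrationRunner::status`.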
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    pub applied: Vec<AppliedMigration>,
    pub pending: Vec<String>,
}

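/// Splits SQL into individual statements on `;`, keeping dollar-quoted bodies (e.g. `$$ ... $$`) intact.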
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        if c == '$' {
            let mut potential_tag = String::from("$");

            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    potential_tag.push(chars.next().unwrap());
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    potential_tag.push(chars.next().unwrap());
                    current.push(potential_tag.chars().last().unwrap());
                } else {
                    break;
                }
            }

            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                if in_dollar_quote && potential_tag == dollar_tag {
                    in_dollar_quote = false;
                    dollar_tag.clear();
                } else if !in_dollar_quote {
                    in_dollar_quote = true;
                    dollar_tag = potential_tag;
                }
            }
        }

        if c == ';' && !in_dollar_quote {
            let stmt = current.trim().trim_end_matches(';').trim().to_string();
            if !stmt.is_empty() {
                statements.push(stmt);
            }
            current.clear();
        }
    }

    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}

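/// Loads `.sql` files from a directory as migrations, sorted by filename; a missing directory yields an empty list.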
pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
    if !dir.exists() {
        debug!("Migrations directory does not exist: {:?}", dir);
        return Ok(Vec::new());
    }

    let mut migrations = Vec::new();

    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;

    for entry in entries {
        let entry = entry.map_err(ForgeError::Io)?;
        let path = entry.path();

        if path.extension().map(|e| e == "sql").unwrap_or(false) {
            let name = path
                .file_stem()
                .and_then(|s| s.to_str())
                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
                .to_string();

            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;

            migrations.push(Migration::parse(name, &content));
        }
    }

    migrations.sort_by(|a, b| a.name.cmp(&b.name));

    debug!("Loaded {} user migrations", migrations.len());
    Ok(migrations)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    #[test]
    fn test_migration_parse_up_only() {
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}