// forge_runtime/migrations/runner.rs

use forge_core::error::{ForgeError, Result};
use sqlx::{PgPool, Postgres};
use std::collections::HashSet;
use std::path::Path;
use tracing::{debug, info, warn};

use super::builtin::extract_version;

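/// Advisory lock key used to serialize concurrent migration runs
/// (0x464F524745 spells "FORGE" in ASCII).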
const MIGRATION_LOCK_ID: i64 = 0x464F524745;

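/// A single SQL migration: an `up` script and an optional `down` script.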
#[derive(Debug, Clone)]
pub struct Migration {
    pub name: String,
    pub up_sql: String,
    pub down_sql: Option<String>,
}

impl Migration {
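    /// Creates a migration with only an up script and no down script.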
    pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            up_sql: sql.into(),
            down_sql: None,
        }
    }

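    /// Creates a migration with both an up and a down script.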
    pub fn with_down(
        name: impl Into<String>,
        up_sql: impl Into<String>,
        down_sql: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            up_sql: up_sql.into(),
            down_sql: Some(down_sql.into()),
        }
    }

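    /// Parses raw migration file content, splitting it into up and down SQL
    /// on the `-- @up` / `-- @down` markers.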
    pub fn parse(name: impl Into<String>, content: &str) -> Self {
        let name = name.into();
        let (up_sql, down_sql) = parse_migration_content(content);
        Self {
            name,
            up_sql,
            down_sql,
        }
    }
}

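/// Splits migration file content into `(up_sql, Option<down_sql>)`.
///
/// Everything before the first `-- @down` marker (any accepted spelling) is
/// the up SQL; everything after it is the down SQL. `-- @up` markers are
/// stripped from the up SQL.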
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    let down_marker_patterns = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];

    for pattern in down_marker_patterns {
        if let Some(idx) = content.find(pattern) {
            let up_part = &content[..idx];
            let down_part = &content[idx + pattern.len()..];

            let up_sql = up_part
                .replace("-- @up", "")
                .replace("--@up", "")
                .replace("-- @UP", "")
                .replace("--@UP", "")
                .trim()
                .to_string();

            let down_sql = down_part.trim().to_string();

            if down_sql.is_empty() {
                return (up_sql, None);
            }
            return (up_sql, Some(down_sql));
        }
    }

    let up_sql = content
        .replace("-- @up", "")
        .replace("--@up", "")
        .replace("-- @UP", "")
        .replace("--@UP", "")
        .trim()
        .to_string();

    (up_sql, None)
}

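/// Applies, rolls back, and reports on migrations against a Postgres pool,
/// holding a session-level advisory lock so concurrent runners cannot race.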
pub struct MigrationRunner {
    pool: PgPool,
}

impl MigrationRunner {
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

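    /// Runs all pending system and user migrations while holding the
    /// migration advisory lock.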
    pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.run_migrations_inner(user_migrations).await;

        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

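    /// Applies system migrations newer than the highest applied system
    /// version, then any user migrations that have not been recorded yet.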
    async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;
        debug!("Already applied migrations: {:?}", applied);

        let max_applied_version = self.get_max_system_version(&applied);
        debug!("Max applied system version: {:?}", max_applied_version);

        let system_migrations = super::builtin::get_system_migrations();
        for sys_migration in system_migrations {
            if let Some(max_ver) = max_applied_version
                && sys_migration.version <= max_ver
            {
                debug!(
                    "Skipping system migration v{} (already at v{})",
                    sys_migration.version, max_ver
                );
                continue;
            }

            let migration = sys_migration.to_migration();
            info!(
                "Applying system migration: {} ({})",
                migration.name, sys_migration.description
            );
            self.apply_migration(&migration).await?;
        }

        for migration in user_migrations {
            if !applied.contains(&migration.name) {
                self.apply_migration(&migration).await?;
            }
        }

        Ok(())
    }

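    /// Returns the highest system-migration version found among the applied
    /// migration names, if any.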
    fn get_max_system_version(&self, applied: &HashSet<String>) -> Option<u32> {
        applied
            .iter()
            .filter_map(|name| extract_version(name))
            .max()
    }

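    /// Takes a dedicated connection from the pool and blocks on
    /// `pg_advisory_lock` until the migration lock is held.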
    async fn acquire_lock_connection(&self) -> Result<sqlx::pool::PoolConnection<Postgres>> {
        debug!("Acquiring migration lock...");
        let mut conn = self.pool.acquire().await.map_err(|e| {
            ForgeError::Database(format!("Failed to acquire lock connection: {}", e))
        })?;

        sqlx::query("SELECT pg_advisory_lock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut *conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
            })?;
        debug!("Migration lock acquired");
        Ok(conn)
    }

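    /// Releases the advisory lock on the connection that acquired it.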
    async fn release_lock_connection(
        &self,
        conn: &mut sqlx::pool::PoolConnection<Postgres>,
    ) -> Result<()> {
        sqlx::query("SELECT pg_advisory_unlock($1)")
            .bind(MIGRATION_LOCK_ID)
            .execute(&mut **conn)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to release migration lock: {}", e))
            })?;
        debug!("Migration lock released");
        Ok(())
    }

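    /// Creates the `forge_migrations` tracking table if it does not exist and
    /// backfills the `down_sql` column on older installs.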
    async fn ensure_migrations_table(&self) -> Result<()> {
        sqlx::query(
            r#"
            CREATE TABLE IF NOT EXISTS forge_migrations (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) UNIQUE NOT NULL,
                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                down_sql TEXT
            )
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;

        sqlx::query(
            r#"
            ALTER TABLE forge_migrations
            ADD COLUMN IF NOT EXISTS down_sql TEXT
            "#,
        )
        .execute(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;

        Ok(())
    }

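    /// Returns the set of migration names already recorded in
    /// `forge_migrations`.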
    async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
        let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM forge_migrations")
            .fetch_all(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!("Failed to get applied migrations: {}", e))
            })?;

        Ok(rows.into_iter().map(|(name,)| name).collect())
    }

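    /// Executes a migration's up SQL statement by statement, then records its
    /// name and down SQL in `forge_migrations`.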
    async fn apply_migration(&self, migration: &Migration) -> Result<()> {
        info!("Applying migration: {}", migration.name);

        let statements = split_sql_statements(&migration.up_sql);

        for statement in statements {
            let statement = statement.trim();

            if statement.is_empty()
                || statement.lines().all(|l| {
                    let l = l.trim();
                    l.is_empty() || l.starts_with("--")
                })
            {
                continue;
            }

            sqlx::query(statement)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to apply migration '{}': {}",
                        migration.name, e
                    ))
                })?;
        }

        sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
            .bind(&migration.name)
            .bind(&migration.down_sql)
            .execute(&self.pool)
            .await
            .map_err(|e| {
                ForgeError::Database(format!(
                    "Failed to record migration '{}': {}",
                    migration.name, e
                ))
            })?;

        info!("Migration applied: {}", migration.name);
        Ok(())
    }

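    /// Rolls back the most recent `count` migrations while holding the
    /// advisory lock, returning the names that were rolled back.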
    pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
        if count == 0 {
            return Ok(Vec::new());
        }

        let mut lock_conn = self.acquire_lock_connection().await?;

        let result = self.rollback_inner(count).await;

        if let Err(e) = self.release_lock_connection(&mut lock_conn).await {
            warn!("Failed to release migration lock: {}", e);
        }

        result
    }

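    /// Runs the stored down SQL (when present) for the newest `count`
    /// migrations, newest first, and deletes their records.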
    async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
        self.ensure_migrations_table().await?;

        let rows: Vec<(i32, String, Option<String>)> = sqlx::query_as(
            "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
        )
        .bind(count as i32)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

        if rows.is_empty() {
            info!("No migrations to rollback");
            return Ok(Vec::new());
        }

        let mut rolled_back = Vec::new();

        for (id, name, down_sql) in rows {
            info!("Rolling back migration: {}", name);

            if let Some(down) = down_sql {
                let statements = split_sql_statements(&down);
                for statement in statements {
                    let statement = statement.trim();
                    if statement.is_empty()
                        || statement.lines().all(|l| {
                            let l = l.trim();
                            l.is_empty() || l.starts_with("--")
                        })
                    {
                        continue;
                    }

                    sqlx::query(statement)
                        .execute(&self.pool)
                        .await
                        .map_err(|e| {
                            ForgeError::Database(format!(
                                "Failed to rollback migration '{}': {}",
                                name, e
                            ))
                        })?;
                }
            } else {
                warn!("Migration '{}' has no down SQL, removing record only", name);
            }

            sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
                .bind(id)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    ForgeError::Database(format!(
                        "Failed to remove migration record '{}': {}",
                        name, e
                    ))
                })?;

            info!("Rolled back migration: {}", name);
            rolled_back.push(name);
        }

        Ok(rolled_back)
    }

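    /// Reports which migrations have been applied and which of the provided
    /// migrations are still pending.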
    pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
        self.ensure_migrations_table().await?;

        let applied = self.get_applied_migrations().await?;

        let applied_list: Vec<AppliedMigration> = {
            let rows: Vec<(String, chrono::DateTime<chrono::Utc>, Option<String>)> =
                sqlx::query_as(
                    "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC",
                )
                .fetch_all(&self.pool)
                .await
                .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;

            rows.into_iter()
                .map(|(name, applied_at, down_sql)| AppliedMigration {
                    name,
                    applied_at,
                    has_down: down_sql.is_some(),
                })
                .collect()
        };

        let pending: Vec<String> = available
            .iter()
            .filter(|m| !applied.contains(&m.name))
            .map(|m| m.name.clone())
            .collect();

        Ok(MigrationStatus {
            applied: applied_list,
            pending,
        })
    }
}

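/// A migration that has already been applied, as recorded in
/// `forge_migrations`.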
#[derive(Debug, Clone)]
pub struct AppliedMigration {
    pub name: String,
    pub applied_at: chrono::DateTime<chrono::Utc>,
    pub has_down: bool,
}

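/// Snapshot of applied and pending migrations returned by
/// `MigrationRunner::status`.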
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    pub applied: Vec<AppliedMigration>,
    pub pending: Vec<String>,
}

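/// Splits a SQL script into individual statements on `;`, treating
/// dollar-quoted bodies (e.g. `$$ ... $$` or `$tag$ ... $tag$`) as opaque so
/// that semicolons inside PL/pgSQL functions do not end a statement.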
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        current.push(c);

        if c == '$' {
            // Collect a potential dollar-quote tag such as `$$` or `$body$`.
            let mut potential_tag = String::from("$");

            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    potential_tag.push(chars.next().expect("peeked char"));
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    let c = chars.next().expect("peeked char");
                    potential_tag.push(c);
                    current.push(c);
                } else {
                    break;
                }
            }

            // A complete tag ends with `$`; toggle the quote state when it
            // opens a new quote or matches the currently open tag.
            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                if in_dollar_quote && potential_tag == dollar_tag {
                    in_dollar_quote = false;
                    dollar_tag.clear();
                } else if !in_dollar_quote {
                    in_dollar_quote = true;
                    dollar_tag = potential_tag;
                }
            }
        }

        if c == ';' && !in_dollar_quote {
            let stmt = current.trim().trim_end_matches(';').trim().to_string();
            if !stmt.is_empty() {
                statements.push(stmt);
            }
            current.clear();
        }
    }

    let stmt = current.trim().trim_end_matches(';').trim().to_string();
    if !stmt.is_empty() {
        statements.push(stmt);
    }

    statements
}

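/// Loads every `*.sql` file in `dir` as a migration named after its file
/// stem, sorted by name. Returns an empty list if the directory is missing.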
pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
    if !dir.exists() {
        debug!("Migrations directory does not exist: {:?}", dir);
        return Ok(Vec::new());
    }

    let mut migrations = Vec::new();

    let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;

    for entry in entries {
        let entry = entry.map_err(ForgeError::Io)?;
        let path = entry.path();

        if path.extension().map(|e| e == "sql").unwrap_or(false) {
            let name = path
                .file_stem()
                .and_then(|s| s.to_str())
                .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
                .to_string();

            let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;

            migrations.push(Migration::parse(name, &content));
        }
    }

    migrations.sort_by(|a, b| a.name.cmp(&b.name));

    debug!("Loaded {} user migrations", migrations.len());
    Ok(migrations)
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    #[test]
    fn test_migration_parse_up_only() {
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}