// forge_runtime/migrations/runner.rs
use forge_core::error::{ForgeError, Result};
use sqlx::PgPool;
use std::collections::HashSet;
use std::path::Path;
use tracing::{debug, info, warn};
/// Advisory-lock key used to serialize migration runs across processes.
/// 0x464F524745 spells "FORGE" in ASCII (46 4F 52 47 45).
const MIGRATION_LOCK_ID: i64 = 0x464F524745;

/// A single SQL migration: forward SQL plus optional rollback SQL.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Unique name; also the sort key and the dedup key in the tracking table.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL executed on rollback; `None` marks the migration as irreversible.
    pub down_sql: Option<String>,
}
25
26impl Migration {
27 pub fn new(name: impl Into<String>, sql: impl Into<String>) -> Self {
29 Self {
30 name: name.into(),
31 up_sql: sql.into(),
32 down_sql: None,
33 }
34 }
35
36 pub fn with_down(
38 name: impl Into<String>,
39 up_sql: impl Into<String>,
40 down_sql: impl Into<String>,
41 ) -> Self {
42 Self {
43 name: name.into(),
44 up_sql: up_sql.into(),
45 down_sql: Some(down_sql.into()),
46 }
47 }
48
49 pub fn parse(name: impl Into<String>, content: &str) -> Self {
51 let name = name.into();
52 let (up_sql, down_sql) = parse_migration_content(content);
53 Self {
54 name,
55 up_sql,
56 down_sql,
57 }
58 }
59}
60
/// Splits migration file content into (up SQL, optional down SQL).
///
/// If any `-- @down` marker variant is found, everything after it becomes
/// the down script; `-- @up` markers are stripped from the up portion.
/// Without a down marker the whole (marker-stripped) content is the up SQL.
fn parse_migration_content(content: &str) -> (String, Option<String>) {
    const DOWN_MARKERS: [&str; 4] = ["-- @down", "--@down", "-- @DOWN", "--@DOWN"];
    const UP_MARKERS: [&str; 4] = ["-- @up", "--@up", "-- @UP", "--@UP"];

    // Remove every up-marker variant and trim surrounding whitespace.
    let strip_up_markers = |text: &str| -> String {
        let mut cleaned = text.to_string();
        for marker in UP_MARKERS {
            cleaned = cleaned.replace(marker, "");
        }
        cleaned.trim().to_string()
    };

    for marker in DOWN_MARKERS {
        if let Some(idx) = content.find(marker) {
            let up_sql = strip_up_markers(&content[..idx]);
            let down = content[idx + marker.len()..].trim();
            // An empty down section is treated as "no down SQL".
            let down_sql = if down.is_empty() {
                None
            } else {
                Some(down.to_string())
            };
            return (up_sql, down_sql);
        }
    }

    (strip_up_markers(content), None)
}
101
/// Applies and rolls back migrations against Postgres, tracking state in
/// the `forge_migrations` table.
pub struct MigrationRunner {
    // Shared connection pool; all queries go through it.
    pool: PgPool,
}
106
107impl MigrationRunner {
108 pub fn new(pool: PgPool) -> Self {
109 Self { pool }
110 }
111
112 pub async fn run(&self, user_migrations: Vec<Migration>) -> Result<()> {
117 self.acquire_lock().await?;
119
120 let result = self.run_migrations_inner(user_migrations).await;
121
122 if let Err(e) = self.release_lock().await {
124 warn!("Failed to release migration lock: {}", e);
125 }
126
127 result
128 }
129
130 async fn run_migrations_inner(&self, user_migrations: Vec<Migration>) -> Result<()> {
131 self.ensure_migrations_table().await?;
133
134 let applied = self.get_applied_migrations().await?;
136 debug!("Already applied migrations: {:?}", applied);
137
138 let builtin = super::builtin::get_builtin_migrations();
140 for migration in builtin {
141 if !applied.contains(&migration.name) {
142 self.apply_migration(&migration).await?;
143 }
144 }
145
146 for migration in user_migrations {
148 if !applied.contains(&migration.name) {
149 self.apply_migration(&migration).await?;
150 }
151 }
152
153 Ok(())
154 }
155
156 async fn acquire_lock(&self) -> Result<()> {
157 debug!("Acquiring migration lock...");
158 sqlx::query("SELECT pg_advisory_lock($1)")
159 .bind(MIGRATION_LOCK_ID)
160 .execute(&self.pool)
161 .await
162 .map_err(|e| {
163 ForgeError::Database(format!("Failed to acquire migration lock: {}", e))
164 })?;
165 debug!("Migration lock acquired");
166 Ok(())
167 }
168
169 async fn release_lock(&self) -> Result<()> {
170 sqlx::query("SELECT pg_advisory_unlock($1)")
171 .bind(MIGRATION_LOCK_ID)
172 .execute(&self.pool)
173 .await
174 .map_err(|e| {
175 ForgeError::Database(format!("Failed to release migration lock: {}", e))
176 })?;
177 debug!("Migration lock released");
178 Ok(())
179 }
180
181 async fn ensure_migrations_table(&self) -> Result<()> {
182 sqlx::query(
184 r#"
185 CREATE TABLE IF NOT EXISTS forge_migrations (
186 id SERIAL PRIMARY KEY,
187 name VARCHAR(255) UNIQUE NOT NULL,
188 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
189 down_sql TEXT
190 )
191 "#,
192 )
193 .execute(&self.pool)
194 .await
195 .map_err(|e| ForgeError::Database(format!("Failed to create migrations table: {}", e)))?;
196
197 sqlx::query(
199 r#"
200 ALTER TABLE forge_migrations
201 ADD COLUMN IF NOT EXISTS down_sql TEXT
202 "#,
203 )
204 .execute(&self.pool)
205 .await
206 .map_err(|e| ForgeError::Database(format!("Failed to add down_sql column: {}", e)))?;
207
208 Ok(())
209 }
210
211 async fn get_applied_migrations(&self) -> Result<HashSet<String>> {
212 let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM forge_migrations")
213 .fetch_all(&self.pool)
214 .await
215 .map_err(|e| {
216 ForgeError::Database(format!("Failed to get applied migrations: {}", e))
217 })?;
218
219 Ok(rows.into_iter().map(|(name,)| name).collect())
220 }
221
222 async fn apply_migration(&self, migration: &Migration) -> Result<()> {
223 info!("Applying migration: {}", migration.name);
224
225 let statements = split_sql_statements(&migration.up_sql);
227
228 for statement in statements {
229 let statement = statement.trim();
230
231 if statement.is_empty()
233 || statement.lines().all(|l| {
234 let l = l.trim();
235 l.is_empty() || l.starts_with("--")
236 })
237 {
238 continue;
239 }
240
241 sqlx::query(statement)
242 .execute(&self.pool)
243 .await
244 .map_err(|e| {
245 ForgeError::Database(format!(
246 "Failed to apply migration '{}': {}",
247 migration.name, e
248 ))
249 })?;
250 }
251
252 sqlx::query("INSERT INTO forge_migrations (name, down_sql) VALUES ($1, $2)")
254 .bind(&migration.name)
255 .bind(&migration.down_sql)
256 .execute(&self.pool)
257 .await
258 .map_err(|e| {
259 ForgeError::Database(format!(
260 "Failed to record migration '{}': {}",
261 migration.name, e
262 ))
263 })?;
264
265 info!("Migration applied: {}", migration.name);
266 Ok(())
267 }
268
269 pub async fn rollback(&self, count: usize) -> Result<Vec<String>> {
271 if count == 0 {
272 return Ok(Vec::new());
273 }
274
275 self.acquire_lock().await?;
277
278 let result = self.rollback_inner(count).await;
279
280 if let Err(e) = self.release_lock().await {
282 warn!("Failed to release migration lock: {}", e);
283 }
284
285 result
286 }
287
288 async fn rollback_inner(&self, count: usize) -> Result<Vec<String>> {
289 self.ensure_migrations_table().await?;
290
291 let rows: Vec<(i32, String, Option<String>)> = sqlx::query_as(
293 "SELECT id, name, down_sql FROM forge_migrations ORDER BY id DESC LIMIT $1",
294 )
295 .bind(count as i32)
296 .fetch_all(&self.pool)
297 .await
298 .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
299
300 if rows.is_empty() {
301 info!("No migrations to rollback");
302 return Ok(Vec::new());
303 }
304
305 let mut rolled_back = Vec::new();
306
307 for (id, name, down_sql) in rows {
308 info!("Rolling back migration: {}", name);
309
310 if let Some(down) = down_sql {
311 let statements = split_sql_statements(&down);
313 for statement in statements {
314 let statement = statement.trim();
315 if statement.is_empty()
316 || statement.lines().all(|l| {
317 let l = l.trim();
318 l.is_empty() || l.starts_with("--")
319 })
320 {
321 continue;
322 }
323
324 sqlx::query(statement)
325 .execute(&self.pool)
326 .await
327 .map_err(|e| {
328 ForgeError::Database(format!(
329 "Failed to rollback migration '{}': {}",
330 name, e
331 ))
332 })?;
333 }
334 } else {
335 warn!("Migration '{}' has no down SQL, removing record only", name);
336 }
337
338 sqlx::query("DELETE FROM forge_migrations WHERE id = $1")
340 .bind(id)
341 .execute(&self.pool)
342 .await
343 .map_err(|e| {
344 ForgeError::Database(format!(
345 "Failed to remove migration record '{}': {}",
346 name, e
347 ))
348 })?;
349
350 info!("Rolled back migration: {}", name);
351 rolled_back.push(name);
352 }
353
354 Ok(rolled_back)
355 }
356
357 pub async fn status(&self, available: &[Migration]) -> Result<MigrationStatus> {
359 self.ensure_migrations_table().await?;
360
361 let applied = self.get_applied_migrations().await?;
362
363 let applied_list: Vec<AppliedMigration> = {
364 let rows: Vec<(String, chrono::DateTime<chrono::Utc>, Option<String>)> =
365 sqlx::query_as(
366 "SELECT name, applied_at, down_sql FROM forge_migrations ORDER BY id ASC",
367 )
368 .fetch_all(&self.pool)
369 .await
370 .map_err(|e| ForgeError::Database(format!("Failed to get migrations: {}", e)))?;
371
372 rows.into_iter()
373 .map(|(name, applied_at, down_sql)| AppliedMigration {
374 name,
375 applied_at,
376 has_down: down_sql.is_some(),
377 })
378 .collect()
379 };
380
381 let pending: Vec<String> = available
382 .iter()
383 .filter(|m| !applied.contains(&m.name))
384 .map(|m| m.name.clone())
385 .collect();
386
387 Ok(MigrationStatus {
388 applied: applied_list,
389 pending,
390 })
391 }
392}
393
/// A migration recorded as applied in `forge_migrations`.
#[derive(Debug, Clone)]
pub struct AppliedMigration {
    /// Migration name as recorded at apply time.
    pub name: String,
    /// UTC timestamp when the migration was applied.
    pub applied_at: chrono::DateTime<chrono::Utc>,
    /// Whether rollback SQL was stored for this migration.
    pub has_down: bool,
}
401
/// Snapshot of migration state: what has run and what is still pending.
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// Migrations already applied, in application order.
    pub applied: Vec<AppliedMigration>,
    /// Names of available migrations not yet applied.
    pub pending: Vec<String>,
}
408
/// Splits a SQL script into individual statements on top-level `;`,
/// while respecting Postgres dollar-quoted blocks (`$$...$$`,
/// `$tag$...$tag$`) and single-quoted string literals (with `''`
/// escaping), so semicolons inside those are not treated as separators.
///
/// Fix over the previous version: semicolons inside single-quoted
/// literals (e.g. `SELECT 'a;b'`) no longer split the statement.
fn split_sql_statements(sql: &str) -> Vec<String> {
    let mut statements = Vec::new();
    let mut current = String::new();
    let mut in_dollar_quote = false;
    let mut dollar_tag = String::new();
    let mut in_single_quote = false;
    let mut chars = sql.chars().peekable();

    // Flush `current` as one statement, dropping trailing semicolons.
    fn flush(statements: &mut Vec<String>, current: &mut String) {
        let stmt = current.trim().trim_end_matches(';').trim().to_string();
        if !stmt.is_empty() {
            statements.push(stmt);
        }
        current.clear();
    }

    while let Some(c) = chars.next() {
        current.push(c);

        if in_single_quote {
            if c == '\'' {
                // `''` is an escaped quote inside a literal; consume the
                // second quote and stay inside the string.
                if chars.peek() == Some(&'\'') {
                    current.push(chars.next().unwrap());
                } else {
                    in_single_quote = false;
                }
            }
            continue;
        }

        // A quote inside a dollar-quoted body is plain text, not a
        // string delimiter.
        if !in_dollar_quote && c == '\'' {
            in_single_quote = true;
            continue;
        }

        if c == '$' {
            // Gather a potential dollar-quote tag: `$`, optional
            // alphanumeric/underscore tag body, closing `$`.
            let mut potential_tag = String::from("$");

            while let Some(&next_c) = chars.peek() {
                if next_c == '$' {
                    potential_tag.push(chars.next().unwrap());
                    current.push('$');
                    break;
                } else if next_c.is_alphanumeric() || next_c == '_' {
                    let ch = chars.next().unwrap();
                    potential_tag.push(ch);
                    current.push(ch);
                } else {
                    break;
                }
            }

            // A complete `$...$` tag toggles quote state — but only a
            // tag matching the opener closes an open block.
            if potential_tag.len() >= 2 && potential_tag.ends_with('$') {
                if in_dollar_quote && potential_tag == dollar_tag {
                    in_dollar_quote = false;
                    dollar_tag.clear();
                } else if !in_dollar_quote {
                    in_dollar_quote = true;
                    dollar_tag = potential_tag;
                }
            }
        }

        if c == ';' && !in_dollar_quote {
            flush(&mut statements, &mut current);
        }
    }

    // Final statement may lack a trailing semicolon.
    flush(&mut statements, &mut current);

    statements
}
472
473pub fn load_migrations_from_dir(dir: &Path) -> Result<Vec<Migration>> {
481 if !dir.exists() {
482 debug!("Migrations directory does not exist: {:?}", dir);
483 return Ok(Vec::new());
484 }
485
486 let mut migrations = Vec::new();
487
488 let entries = std::fs::read_dir(dir).map_err(ForgeError::Io)?;
489
490 for entry in entries {
491 let entry = entry.map_err(ForgeError::Io)?;
492 let path = entry.path();
493
494 if path.extension().map(|e| e == "sql").unwrap_or(false) {
495 let name = path
496 .file_stem()
497 .and_then(|s| s.to_str())
498 .ok_or_else(|| ForgeError::Config("Invalid migration filename".into()))?
499 .to_string();
500
501 let content = std::fs::read_to_string(&path).map_err(ForgeError::Io)?;
502
503 migrations.push(Migration::parse(name, &content));
504 }
505 }
506
507 migrations.sort_by(|a, b| a.name.cmp(&b.name));
509
510 debug!("Loaded {} user migrations", migrations.len());
511 Ok(migrations)
512}
513
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    // An existing-but-empty directory yields no migrations.
    #[test]
    fn test_load_migrations_from_empty_dir() {
        let dir = TempDir::new().unwrap();
        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert!(migrations.is_empty());
    }

    // A missing directory is treated as "no migrations", not an error.
    #[test]
    fn test_load_migrations_from_nonexistent_dir() {
        let migrations = load_migrations_from_dir(Path::new("/nonexistent/path")).unwrap();
        assert!(migrations.is_empty());
    }

    // Files are returned sorted by name regardless of creation order.
    #[test]
    fn test_load_migrations_sorted() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0002_second.sql"), "SELECT 2;").unwrap();
        fs::write(dir.path().join("0001_first.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("0003_third.sql"), "SELECT 3;").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 3);
        assert_eq!(migrations[0].name, "0001_first");
        assert_eq!(migrations[1].name, "0002_second");
        assert_eq!(migrations[2].name, "0003_third");
    }

    // Only files with a final ".sql" extension are picked up.
    #[test]
    fn test_load_migrations_ignores_non_sql() {
        let dir = TempDir::new().unwrap();

        fs::write(dir.path().join("0001_migration.sql"), "SELECT 1;").unwrap();
        fs::write(dir.path().join("readme.txt"), "Not a migration").unwrap();
        fs::write(dir.path().join("backup.sql.bak"), "Backup").unwrap();

        let migrations = load_migrations_from_dir(dir.path()).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(migrations[0].name, "0001_migration");
    }

    // Migration::new stores up SQL only.
    #[test]
    fn test_migration_new() {
        let m = Migration::new("test", "SELECT 1");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "SELECT 1");
        assert!(m.down_sql.is_none());
    }

    // Migration::with_down stores both directions.
    #[test]
    fn test_migration_with_down() {
        let m = Migration::with_down("test", "CREATE TABLE t()", "DROP TABLE t");
        assert_eq!(m.name, "test");
        assert_eq!(m.up_sql, "CREATE TABLE t()");
        assert_eq!(m.down_sql, Some("DROP TABLE t".to_string()));
    }

    // Content without markers becomes the up SQL verbatim.
    #[test]
    fn test_migration_parse_up_only() {
        let content = "CREATE TABLE users (id INT);";
        let m = Migration::parse("0001_test", content);
        assert_eq!(m.name, "0001_test");
        assert_eq!(m.up_sql, "CREATE TABLE users (id INT);");
        assert!(m.down_sql.is_none());
    }

    // @up/@down markers split content and are stripped from the output.
    #[test]
    fn test_migration_parse_with_markers() {
        let content = r#"
-- @up
CREATE TABLE users (
    id UUID PRIMARY KEY,
    email VARCHAR(255)
);

-- @down
DROP TABLE users;
"#;
        let m = Migration::parse("0001_users", content);
        assert_eq!(m.name, "0001_users");
        assert!(m.up_sql.contains("CREATE TABLE users"));
        assert!(!m.up_sql.contains("@up"));
        assert!(!m.up_sql.contains("DROP TABLE"));
        assert_eq!(m.down_sql, Some("DROP TABLE users;".to_string()));
    }

    // Multi-statement sections stay intact on each side of the marker.
    #[test]
    fn test_migration_parse_complex() {
        let content = r#"
-- @up
CREATE TABLE posts (
    id UUID PRIMARY KEY,
    title TEXT NOT NULL
);
CREATE INDEX idx_posts_title ON posts(title);

-- @down
DROP INDEX idx_posts_title;
DROP TABLE posts;
"#;
        let m = Migration::parse("0002_posts", content);
        assert!(m.up_sql.contains("CREATE TABLE posts"));
        assert!(m.up_sql.contains("CREATE INDEX"));
        let down = m.down_sql.unwrap();
        assert!(down.contains("DROP INDEX"));
        assert!(down.contains("DROP TABLE posts"));
    }

    // Plain statements split on semicolons, which are stripped.
    #[test]
    fn test_split_simple_statements() {
        let sql = "SELECT 1; SELECT 2; SELECT 3;";
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 3);
        assert_eq!(stmts[0], "SELECT 1");
        assert_eq!(stmts[1], "SELECT 2");
        assert_eq!(stmts[2], "SELECT 3");
    }

    // Semicolons inside $$...$$ bodies must not split the function.
    #[test]
    fn test_split_with_dollar_quoted_function() {
        let sql = r#"
CREATE FUNCTION test() RETURNS void AS $$
BEGIN
    SELECT 1;
    SELECT 2;
END;
$$ LANGUAGE plpgsql;

SELECT 3;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 2);
        assert!(stmts[0].contains("CREATE FUNCTION"));
        assert!(stmts[0].contains("$$ LANGUAGE plpgsql"));
        assert!(stmts[1].contains("SELECT 3"));
    }

    // Dollar-quoted body text is preserved byte-for-byte.
    #[test]
    fn test_split_preserves_dollar_quote_content() {
        let sql = r#"
CREATE FUNCTION notify() RETURNS trigger AS $$
DECLARE
    row_id TEXT;
BEGIN
    row_id := NEW.id::TEXT;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
"#;
        let stmts = super::split_sql_statements(sql);
        assert_eq!(stmts.len(), 1);
        assert!(stmts[0].contains("row_id := NEW.id::TEXT"));
    }
}