// forge_runtime/migrations/executor.rs

use chrono::{DateTime, Utc};
use forge_core::error::{ForgeError, Result};
use sqlx::{Executor, PgPool, Row};

use super::generator::Migration;

7/// Executes migrations against a database.
8pub struct MigrationExecutor {
9    pool: PgPool,
10}
11
12impl MigrationExecutor {
13    /// Create a new migration executor.
14    pub fn new(pool: PgPool) -> Self {
15        Self { pool }
16    }
17
18    /// Initialize the migrations table.
19    pub async fn init(&self) -> Result<()> {
20        sqlx::query(
21            r#"
22            CREATE TABLE IF NOT EXISTS forge_migrations (
23                id SERIAL PRIMARY KEY,
24                version VARCHAR(255) UNIQUE NOT NULL,
25                name VARCHAR(255),
26                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
27                checksum VARCHAR(64),
28                execution_time_ms INTEGER
29            )
30            "#,
31        )
32        .execute(&self.pool)
33        .await
34        .map_err(|e| ForgeError::Database(format!("Failed to init migrations table: {}", e)))?;
35
36        Ok(())
37    }
38
39    /// Get all applied migrations.
40    pub async fn applied_migrations(&self) -> Result<Vec<AppliedMigration>> {
41        let rows = sqlx::query(
42            r#"
43            SELECT version, name, applied_at, checksum, execution_time_ms
44            FROM forge_migrations
45            ORDER BY version ASC
46            "#,
47        )
48        .fetch_all(&self.pool)
49        .await
50        .map_err(|e| ForgeError::Database(format!("Failed to fetch migrations: {}", e)))?;
51
52        let migrations = rows
53            .iter()
54            .map(|row| AppliedMigration {
55                version: row.get("version"),
56                name: row.get("name"),
57                applied_at: row.get("applied_at"),
58                checksum: row.get("checksum"),
59                execution_time_ms: row.get("execution_time_ms"),
60            })
61            .collect();
62
63        Ok(migrations)
64    }
65
66    /// Check if a migration has been applied.
67    pub async fn is_applied(&self, version: &str) -> Result<bool> {
68        let row = sqlx::query_scalar::<_, i64>(
69            "SELECT COUNT(*) FROM forge_migrations WHERE version = $1",
70        )
71        .bind(version)
72        .fetch_one(&self.pool)
73        .await
74        .map_err(|e| ForgeError::Database(format!("Failed to check migration: {}", e)))?;
75
76        Ok(row > 0)
77    }
78
79    /// Apply a migration.
80    pub async fn apply(&self, migration: &Migration) -> Result<()> {
81        let start = std::time::Instant::now();
82
83        // Execute the migration SQL
84        sqlx::query(&migration.sql)
85            .execute(&self.pool)
86            .await
87            .map_err(|e| {
88                ForgeError::Database(format!(
89                    "Failed to apply migration {}: {}",
90                    migration.version, e
91                ))
92            })?;
93
94        let elapsed = start.elapsed();
95
96        // Calculate checksum
97        let checksum = calculate_checksum(&migration.sql);
98
99        // Record the migration
100        sqlx::query(
101            r#"
102            INSERT INTO forge_migrations (version, name, checksum, execution_time_ms)
103            VALUES ($1, $2, $3, $4)
104            "#,
105        )
106        .bind(&migration.version)
107        .bind(&migration.name)
108        .bind(&checksum)
109        .bind(elapsed.as_millis() as i32)
110        .execute(&self.pool)
111        .await
112        .map_err(|e| {
113            ForgeError::Database(format!(
114                "Failed to record migration {}: {}",
115                migration.version, e
116            ))
117        })?;
118
119        Ok(())
120    }
121
122    /// Rollback the last migration.
123    pub async fn rollback(&self) -> Result<Option<String>> {
124        // Get the last applied migration
125        let last = sqlx::query_scalar::<_, String>(
126            "SELECT version FROM forge_migrations ORDER BY version DESC LIMIT 1",
127        )
128        .fetch_optional(&self.pool)
129        .await
130        .map_err(|e| ForgeError::Database(format!("Failed to get last migration: {}", e)))?;
131
132        match last {
133            Some(version) => {
134                // Remove from migrations table
135                sqlx::query("DELETE FROM forge_migrations WHERE version = $1")
136                    .bind(&version)
137                    .execute(&self.pool)
138                    .await
139                    .map_err(|e| {
140                        ForgeError::Database(format!("Failed to remove migration record: {}", e))
141                    })?;
142
143                Ok(Some(version))
144            }
145            None => Ok(None),
146        }
147    }
148}
149
150/// A migration that has been applied.
151#[derive(Debug, Clone)]
152pub struct AppliedMigration {
153    pub version: String,
154    pub name: Option<String>,
155    pub applied_at: DateTime<Utc>,
156    pub checksum: Option<String>,
157    pub execution_time_ms: Option<i32>,
158}
159
/// Compute a checksum of the migration content as 16 lowercase hex digits.
///
/// Uses 64-bit FNV-1a over the UTF-8 bytes of `content`. The value is persisted in
/// `forge_migrations.checksum`, so it must be identical across builds and toolchains.
/// `std::collections::hash_map::DefaultHasher` gives no such guarantee, because its
/// algorithm is unspecified and may change between Rust releases. FNV-1a is fully
/// specified and needs no external crate.
///
/// This is a change-detection checksum, not a cryptographic hash.
///
/// NOTE(review): rows recorded by builds that used `DefaultHasher` hold different
/// values; account for that if checksums are ever compared against stored ones.
fn calculate_checksum(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    format!("{:016x}", hash)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checksum() {
        let users_sql = "CREATE TABLE users (id UUID);";

        let first = calculate_checksum(users_sql);
        let again = calculate_checksum(users_sql);
        let other = calculate_checksum("CREATE TABLE posts (id UUID);");

        // Fixed-width output.
        assert_eq!(first.len(), 16);
        // Identical input always yields the identical checksum...
        assert_eq!(first, again);
        // ...while different input yields a different one.
        assert_ne!(first, other);
    }
}
187}