Skip to main content

forge_runtime/migrations/
executor.rs

1use chrono::{DateTime, Utc};
2use sqlx::PgPool;
3
4use super::generator::Migration;
5use forge_core::error::{ForgeError, Result};
6
7/// Executes migrations against a database.
8pub struct MigrationExecutor {
9    pool: PgPool,
10}
11
12impl MigrationExecutor {
13    /// Create a new migration executor.
14    pub fn new(pool: PgPool) -> Self {
15        Self { pool }
16    }
17
18    /// Initialize the migrations table.
19    pub async fn init(&self) -> Result<()> {
20        sqlx::query(
21            r#"
22            CREATE TABLE IF NOT EXISTS forge_migrations (
23                id SERIAL PRIMARY KEY,
24                version VARCHAR(255) UNIQUE NOT NULL,
25                name VARCHAR(255),
26                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
27                checksum VARCHAR(64),
28                execution_time_ms INTEGER
29            )
30            "#,
31        )
32        .execute(&self.pool)
33        .await
34        .map_err(|e| ForgeError::Database(format!("Failed to init migrations table: {}", e)))?;
35
36        Ok(())
37    }
38
39    /// Get all applied migrations.
40    pub async fn applied_migrations(&self) -> Result<Vec<AppliedMigration>> {
41        let rows = sqlx::query!(
42            r#"
43            SELECT version as "version!", name, applied_at, checksum, execution_time_ms
44            FROM forge_migrations
45            ORDER BY version ASC
46            "#,
47        )
48        .fetch_all(&self.pool)
49        .await
50        .map_err(|e| ForgeError::Database(format!("Failed to fetch migrations: {}", e)))?;
51
52        let migrations = rows
53            .into_iter()
54            .map(|row| AppliedMigration {
55                version: row.version,
56                name: Some(row.name),
57                applied_at: row.applied_at,
58                checksum: row.checksum,
59                execution_time_ms: row.execution_time_ms,
60            })
61            .collect();
62
63        Ok(migrations)
64    }
65
66    /// Check if a migration has been applied.
67    pub async fn is_applied(&self, version: &str) -> Result<bool> {
68        let row = sqlx::query_scalar!(
69            "SELECT COUNT(*) FROM forge_migrations WHERE version = $1",
70            version,
71        )
72        .fetch_one(&self.pool)
73        .await
74        .map_err(|e| ForgeError::Database(format!("Failed to check migration: {}", e)))?;
75
76        Ok(row.unwrap_or(0) > 0)
77    }
78
79    /// Apply a migration.
80    pub async fn apply(&self, migration: &Migration) -> Result<()> {
81        let start = std::time::Instant::now();
82
83        // Execute the migration SQL
84        sqlx::query(&migration.sql)
85            .execute(&self.pool)
86            .await
87            .map_err(|e| {
88                ForgeError::Database(format!(
89                    "Failed to apply migration {}: {}",
90                    migration.version, e
91                ))
92            })?;
93
94        let elapsed = start.elapsed();
95
96        // Calculate checksum
97        let checksum = calculate_checksum(&migration.sql);
98
99        // Record the migration
100        sqlx::query!(
101            r#"
102            INSERT INTO forge_migrations (version, name, checksum, execution_time_ms)
103            VALUES ($1, $2, $3, $4)
104            "#,
105            &migration.version,
106            &migration.name,
107            &checksum,
108            elapsed.as_millis() as i32,
109        )
110        .execute(&self.pool)
111        .await
112        .map_err(|e| {
113            ForgeError::Database(format!(
114                "Failed to record migration {}: {}",
115                migration.version, e
116            ))
117        })?;
118
119        Ok(())
120    }
121
122    /// Rollback the last migration.
123    pub async fn rollback(&self) -> Result<Option<String>> {
124        // Get the last applied migration
125        let last = sqlx::query_scalar!(
126            r#"SELECT version as "version!" FROM forge_migrations ORDER BY version DESC LIMIT 1"#,
127        )
128        .fetch_optional(&self.pool)
129        .await
130        .map_err(|e| ForgeError::Database(format!("Failed to get last migration: {}", e)))?;
131
132        match last {
133            Some(version) => {
134                // Remove from migrations table
135                sqlx::query!("DELETE FROM forge_migrations WHERE version = $1", &version)
136                    .execute(&self.pool)
137                    .await
138                    .map_err(|e| {
139                        ForgeError::Database(format!("Failed to remove migration record: {}", e))
140                    })?;
141
142                Ok(Some(version))
143            }
144            None => Ok(None),
145        }
146    }
147}
148
149/// A migration that has been applied.
150#[derive(Debug, Clone)]
151pub struct AppliedMigration {
152    pub version: String,
153    pub name: Option<String>,
154    pub applied_at: DateTime<Utc>,
155    pub checksum: Option<String>,
156    pub execution_time_ms: Option<i32>,
157}
158
/// Calculate a stable 64-bit FNV-1a checksum of the migration content.
///
/// The digest is rendered as 16 lowercase hex digits. FNV-1a is computed by
/// hand rather than with `std::collections::hash_map::DefaultHasher`, because
/// the standard library does not guarantee that `DefaultHasher` produces the
/// same output across Rust releases. That matters because this value is
/// persisted in `forge_migrations` and must stay comparable over time.
///
/// This is a change-detection fingerprint, not a cryptographic hash.
///
/// NOTE(review): rows recorded before this algorithm was adopted hold
/// checksums produced by `DefaultHasher`, so they will not match.
fn calculate_checksum(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:016x}", digest)
}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// The checksum is a fixed-width, 16-character string that is identical
    /// for identical input and differs between distinct migration bodies.
    #[test]
    fn test_checksum() {
        let users_sql = "CREATE TABLE users (id UUID);";
        let posts_sql = "CREATE TABLE posts (id UUID);";

        let first = calculate_checksum(users_sql);
        assert_eq!(first.len(), 16);

        // Hashing the same text again yields the same digest.
        let again = calculate_checksum(users_sql);
        assert_eq!(first, again);

        // A different migration body yields a different digest.
        let other = calculate_checksum(posts_sql);
        assert_ne!(first, other);
    }
}