// forge_runtime/migrations/executor.rs
use chrono::{DateTime, Utc};
use sqlx::{PgPool, Row};

use super::generator::Migration;
use forge_core::error::{ForgeError, Result};
6
7pub struct MigrationExecutor {
9 pool: PgPool,
10}
11
12impl MigrationExecutor {
13 pub fn new(pool: PgPool) -> Self {
15 Self { pool }
16 }
17
18 pub async fn init(&self) -> Result<()> {
20 sqlx::query(
21 r#"
22 CREATE TABLE IF NOT EXISTS forge_migrations (
23 id SERIAL PRIMARY KEY,
24 version VARCHAR(255) UNIQUE NOT NULL,
25 name VARCHAR(255),
26 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
27 checksum VARCHAR(64),
28 execution_time_ms INTEGER
29 )
30 "#,
31 )
32 .execute(&self.pool)
33 .await
34 .map_err(|e| ForgeError::Database(format!("Failed to init migrations table: {}", e)))?;
35
36 Ok(())
37 }
38
39 pub async fn applied_migrations(&self) -> Result<Vec<AppliedMigration>> {
41 let rows = sqlx::query(
42 r#"
43 SELECT version, name, applied_at, checksum, execution_time_ms
44 FROM forge_migrations
45 ORDER BY version ASC
46 "#,
47 )
48 .fetch_all(&self.pool)
49 .await
50 .map_err(|e| ForgeError::Database(format!("Failed to fetch migrations: {}", e)))?;
51
52 let migrations = rows
53 .iter()
54 .map(|row| AppliedMigration {
55 version: row.get("version"),
56 name: row.get("name"),
57 applied_at: row.get("applied_at"),
58 checksum: row.get("checksum"),
59 execution_time_ms: row.get("execution_time_ms"),
60 })
61 .collect();
62
63 Ok(migrations)
64 }
65
66 pub async fn is_applied(&self, version: &str) -> Result<bool> {
68 let row = sqlx::query_scalar::<_, i64>(
69 "SELECT COUNT(*) FROM forge_migrations WHERE version = $1",
70 )
71 .bind(version)
72 .fetch_one(&self.pool)
73 .await
74 .map_err(|e| ForgeError::Database(format!("Failed to check migration: {}", e)))?;
75
76 Ok(row > 0)
77 }
78
79 pub async fn apply(&self, migration: &Migration) -> Result<()> {
81 let start = std::time::Instant::now();
82
83 sqlx::query(&migration.sql)
85 .execute(&self.pool)
86 .await
87 .map_err(|e| {
88 ForgeError::Database(format!(
89 "Failed to apply migration {}: {}",
90 migration.version, e
91 ))
92 })?;
93
94 let elapsed = start.elapsed();
95
96 let checksum = calculate_checksum(&migration.sql);
98
99 sqlx::query(
101 r#"
102 INSERT INTO forge_migrations (version, name, checksum, execution_time_ms)
103 VALUES ($1, $2, $3, $4)
104 "#,
105 )
106 .bind(&migration.version)
107 .bind(&migration.name)
108 .bind(&checksum)
109 .bind(elapsed.as_millis() as i32)
110 .execute(&self.pool)
111 .await
112 .map_err(|e| {
113 ForgeError::Database(format!(
114 "Failed to record migration {}: {}",
115 migration.version, e
116 ))
117 })?;
118
119 Ok(())
120 }
121
122 pub async fn rollback(&self) -> Result<Option<String>> {
124 let last = sqlx::query_scalar::<_, String>(
126 "SELECT version FROM forge_migrations ORDER BY version DESC LIMIT 1",
127 )
128 .fetch_optional(&self.pool)
129 .await
130 .map_err(|e| ForgeError::Database(format!("Failed to get last migration: {}", e)))?;
131
132 match last {
133 Some(version) => {
134 sqlx::query("DELETE FROM forge_migrations WHERE version = $1")
136 .bind(&version)
137 .execute(&self.pool)
138 .await
139 .map_err(|e| {
140 ForgeError::Database(format!("Failed to remove migration record: {}", e))
141 })?;
142
143 Ok(Some(version))
144 }
145 None => Ok(None),
146 }
147 }
148}
149
150#[derive(Debug, Clone)]
152pub struct AppliedMigration {
153 pub version: String,
154 pub name: Option<String>,
155 pub applied_at: DateTime<Utc>,
156 pub checksum: Option<String>,
157 pub execution_time_ms: Option<i32>,
158}
159
/// Computes a 16-hex-digit checksum of `content` using 64-bit FNV-1a.
///
/// The result is meant to be persisted, so it must be identical on every
/// build of the program. `std`'s `DefaultHasher` gives no such guarantee (its
/// algorithm is unspecified and may change between Rust releases), which is
/// why a fixed, fully specified algorithm is used here.
///
/// This is a change-detection checksum, not a cryptographic digest.
///
/// NOTE(review): checksums stored by a build that used a different hashing
/// algorithm will not match values produced by this function.
fn calculate_checksum(content: &str) -> String {
    // Parameters of the 64-bit FNV-1a hash.
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = content.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    // Zero-padded so the string is always exactly 16 characters long.
    format!("{:016x}", digest)
}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// The checksum is 16 characters long, repeatable for equal input, and
    /// different for different input.
    #[test]
    fn test_checksum() {
        let users_sql = "CREATE TABLE users (id UUID);";

        let first = calculate_checksum(users_sql);
        assert_eq!(first.len(), 16);

        let repeated = calculate_checksum(users_sql);
        assert_eq!(first, repeated);

        let other = calculate_checksum("CREATE TABLE posts (id UUID);");
        assert_ne!(first, other);
    }
}