forge_runtime/migrations/
executor.rs1use chrono::{DateTime, Utc};
2use sqlx::PgPool;
3
4use super::generator::Migration;
5use forge_core::error::{ForgeError, Result};
6
7pub struct MigrationExecutor {
9 pool: PgPool,
10}
11
12impl MigrationExecutor {
13 pub fn new(pool: PgPool) -> Self {
15 Self { pool }
16 }
17
18 pub async fn init(&self) -> Result<()> {
20 sqlx::query(
21 r#"
22 CREATE TABLE IF NOT EXISTS forge_migrations (
23 id SERIAL PRIMARY KEY,
24 version VARCHAR(255) UNIQUE NOT NULL,
25 name VARCHAR(255),
26 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
27 checksum VARCHAR(64),
28 execution_time_ms INTEGER
29 )
30 "#,
31 )
32 .execute(&self.pool)
33 .await
34 .map_err(|e| ForgeError::Database(format!("Failed to init migrations table: {}", e)))?;
35
36 Ok(())
37 }
38
39 pub async fn applied_migrations(&self) -> Result<Vec<AppliedMigration>> {
41 let rows = sqlx::query!(
42 r#"
43 SELECT version as "version!", name, applied_at, checksum, execution_time_ms
44 FROM forge_migrations
45 ORDER BY version ASC
46 "#,
47 )
48 .fetch_all(&self.pool)
49 .await
50 .map_err(|e| ForgeError::Database(format!("Failed to fetch migrations: {}", e)))?;
51
52 let migrations = rows
53 .into_iter()
54 .map(|row| AppliedMigration {
55 version: row.version,
56 name: Some(row.name),
57 applied_at: row.applied_at,
58 checksum: row.checksum,
59 execution_time_ms: row.execution_time_ms,
60 })
61 .collect();
62
63 Ok(migrations)
64 }
65
66 pub async fn is_applied(&self, version: &str) -> Result<bool> {
68 let row = sqlx::query_scalar!(
69 "SELECT COUNT(*) FROM forge_migrations WHERE version = $1",
70 version,
71 )
72 .fetch_one(&self.pool)
73 .await
74 .map_err(|e| ForgeError::Database(format!("Failed to check migration: {}", e)))?;
75
76 Ok(row.unwrap_or(0) > 0)
77 }
78
79 pub async fn apply(&self, migration: &Migration) -> Result<()> {
81 let start = std::time::Instant::now();
82
83 sqlx::query(&migration.sql)
85 .execute(&self.pool)
86 .await
87 .map_err(|e| {
88 ForgeError::Database(format!(
89 "Failed to apply migration {}: {}",
90 migration.version, e
91 ))
92 })?;
93
94 let elapsed = start.elapsed();
95
96 let checksum = calculate_checksum(&migration.sql);
98
99 sqlx::query!(
101 r#"
102 INSERT INTO forge_migrations (version, name, checksum, execution_time_ms)
103 VALUES ($1, $2, $3, $4)
104 "#,
105 &migration.version,
106 &migration.name,
107 &checksum,
108 elapsed.as_millis() as i32,
109 )
110 .execute(&self.pool)
111 .await
112 .map_err(|e| {
113 ForgeError::Database(format!(
114 "Failed to record migration {}: {}",
115 migration.version, e
116 ))
117 })?;
118
119 Ok(())
120 }
121
122 pub async fn rollback(&self) -> Result<Option<String>> {
124 let last = sqlx::query_scalar!(
126 r#"SELECT version as "version!" FROM forge_migrations ORDER BY version DESC LIMIT 1"#,
127 )
128 .fetch_optional(&self.pool)
129 .await
130 .map_err(|e| ForgeError::Database(format!("Failed to get last migration: {}", e)))?;
131
132 match last {
133 Some(version) => {
134 sqlx::query!("DELETE FROM forge_migrations WHERE version = $1", &version)
136 .execute(&self.pool)
137 .await
138 .map_err(|e| {
139 ForgeError::Database(format!("Failed to remove migration record: {}", e))
140 })?;
141
142 Ok(Some(version))
143 }
144 None => Ok(None),
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
151pub struct AppliedMigration {
152 pub version: String,
153 pub name: Option<String>,
154 pub applied_at: DateTime<Utc>,
155 pub checksum: Option<String>,
156 pub execution_time_ms: Option<i32>,
157}
158
/// Computes a stable checksum of `content`: 64-bit FNV-1a over its UTF-8 bytes,
/// rendered as 16 lowercase hex digits.
///
/// The result is persisted (`forge_migrations.checksum`), so it must not change
/// between builds. `DefaultHasher` explicitly makes no such promise, because its
/// algorithm may change across Rust releases. FNV-1a is a fixed, published
/// algorithm, so it is used instead.
///
/// This is not a cryptographic hash. NOTE(review): checksums recorded by earlier
/// builds that used `DefaultHasher` will not match values produced here.
fn calculate_checksum(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = content
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });

    format!("{:016x}", hash)
}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Checksums are fixed-width, deterministic for identical input, and
    /// sensitive to content changes.
    #[test]
    fn test_checksum() {
        let users_sql = "CREATE TABLE users (id UUID);";
        let posts_sql = "CREATE TABLE posts (id UUID);";

        let first = calculate_checksum(users_sql);
        let repeated = calculate_checksum(users_sql);
        let different = calculate_checksum(posts_sql);

        assert_eq!(first.len(), 16);
        assert_eq!(first, repeated);
        assert_ne!(first, different);
    }
}