use chrono::{DateTime, Utc};
use sqlx::{PgPool, Row};
use super::generator::Migration;
use forge_core::error::{ForgeError, Result};
pub struct MigrationExecutor {
pool: PgPool,
}
impl MigrationExecutor {
pub fn new(pool: PgPool) -> Self {
Self { pool }
}
pub async fn init(&self) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS forge_migrations (
id SERIAL PRIMARY KEY,
version VARCHAR(255) UNIQUE NOT NULL,
name VARCHAR(255),
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
checksum VARCHAR(64),
execution_time_ms INTEGER
)
"#,
)
.execute(&self.pool)
.await
.map_err(|e| ForgeError::Database(format!("Failed to init migrations table: {}", e)))?;
Ok(())
}
pub async fn applied_migrations(&self) -> Result<Vec<AppliedMigration>> {
let rows = sqlx::query(
r#"
SELECT version, name, applied_at, checksum, execution_time_ms
FROM forge_migrations
ORDER BY version ASC
"#,
)
.fetch_all(&self.pool)
.await
.map_err(|e| ForgeError::Database(format!("Failed to fetch migrations: {}", e)))?;
let migrations = rows
.iter()
.map(|row| AppliedMigration {
version: row.get("version"),
name: row.get("name"),
applied_at: row.get("applied_at"),
checksum: row.get("checksum"),
execution_time_ms: row.get("execution_time_ms"),
})
.collect();
Ok(migrations)
}
pub async fn is_applied(&self, version: &str) -> Result<bool> {
let row = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM forge_migrations WHERE version = $1",
)
.bind(version)
.fetch_one(&self.pool)
.await
.map_err(|e| ForgeError::Database(format!("Failed to check migration: {}", e)))?;
Ok(row > 0)
}
pub async fn apply(&self, migration: &Migration) -> Result<()> {
let start = std::time::Instant::now();
sqlx::query(&migration.sql)
.execute(&self.pool)
.await
.map_err(|e| {
ForgeError::Database(format!(
"Failed to apply migration {}: {}",
migration.version, e
))
})?;
let elapsed = start.elapsed();
let checksum = calculate_checksum(&migration.sql);
sqlx::query(
r#"
INSERT INTO forge_migrations (version, name, checksum, execution_time_ms)
VALUES ($1, $2, $3, $4)
"#,
)
.bind(&migration.version)
.bind(&migration.name)
.bind(&checksum)
.bind(elapsed.as_millis() as i32)
.execute(&self.pool)
.await
.map_err(|e| {
ForgeError::Database(format!(
"Failed to record migration {}: {}",
migration.version, e
))
})?;
Ok(())
}
pub async fn rollback(&self) -> Result<Option<String>> {
let last = sqlx::query_scalar::<_, String>(
"SELECT version FROM forge_migrations ORDER BY version DESC LIMIT 1",
)
.fetch_optional(&self.pool)
.await
.map_err(|e| ForgeError::Database(format!("Failed to get last migration: {}", e)))?;
match last {
Some(version) => {
sqlx::query("DELETE FROM forge_migrations WHERE version = $1")
.bind(&version)
.execute(&self.pool)
.await
.map_err(|e| {
ForgeError::Database(format!("Failed to remove migration record: {}", e))
})?;
Ok(Some(version))
}
None => Ok(None),
}
}
}
/// One row of the `forge_migrations` table, as produced by
/// `MigrationExecutor::applied_migrations`.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Migration version; unique and non-null in the table.
    pub version: String,
    /// Optional human-readable name (`NULL` maps to `None`).
    pub name: Option<String>,
    /// Time the record was inserted (the column defaults to `NOW()`).
    pub applied_at: DateTime<Utc>,
    /// Checksum of the migration SQL captured at apply time, if any.
    pub checksum: Option<String>,
    /// Measured run time of the migration SQL in milliseconds, if any.
    pub execution_time_ms: Option<i32>,
}
/// Computes a stable 64-bit checksum of `content` as 16 lowercase hex digits.
///
/// Uses 64-bit FNV-1a over the UTF-8 bytes. The previous implementation used
/// `std::collections::hash_map::DefaultHasher`, whose algorithm is documented
/// as unspecified and may change between Rust releases. A checksum that is
/// written to the database must be reproducible by later builds, so a fixed,
/// well-known algorithm is used instead.
///
/// This is not a cryptographic hash: it detects accidental edits, not
/// deliberate tampering.
///
/// Note: the values differ from those of the old `DefaultHasher`-based
/// implementation, so checksums stored by older builds will not match.
fn calculate_checksum(content: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let hash = content.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:016x}", hash)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The checksum is a fixed-width (16 char) digest that is deterministic
    /// for identical input and differs for different input.
    #[test]
    fn test_checksum() {
        let users_sql = "CREATE TABLE users (id UUID);";
        let posts_sql = "CREATE TABLE posts (id UUID);";

        let first = calculate_checksum(users_sql);
        assert_eq!(first.len(), 16);

        // Hashing the same statement again must give the same digest.
        let again = calculate_checksum(users_sql);
        assert_eq!(first, again);

        // A different statement must give a different digest.
        let other = calculate_checksum(posts_sql);
        assert_ne!(first, other);
    }
}