use super::{Migration, MigrationRecord, MigrationRunner};
use crate::Result;
use chrono::Utc;
use sqlparser::{dialect::MySqlDialect, parser::Parser};
use sqlx::{MySqlPool, Row};
use tracing::{debug, info, warn};
/// Splits a MySQL migration script into individual statements using `sqlparser`.
///
/// The script is parsed with the MySQL dialect; every parsed statement is rendered
/// back to text through its `Display` implementation and terminated with `;`.
///
/// # Errors
///
/// Returns the `ParserError` when the script cannot be parsed. The error is
/// propagated rather than hidden behind a silent fallback so that the caller can
/// log it and decide how to degrade (for example by splitting on `";\n"`).
fn parse_mysql_statements(
    sql: &str,
) -> std::result::Result<Vec<String>, sqlparser::parser::ParserError> {
    let dialect = MySqlDialect {};
    let statements = Parser::parse_sql(&dialect, sql)?;
    Ok(statements.iter().map(|stmt| format!("{};", stmt)).collect())
}
/// Migration runner backed by a MySQL connection pool.
///
/// Construct it with `MySqlMigrationRunner::new`.
pub struct MySqlMigrationRunner {
/// Pool through which every migration query of this runner is issued.
pool: MySqlPool,
}
impl MySqlMigrationRunner {
pub fn new(pool: MySqlPool) -> Self {
Self { pool }
}
}
#[async_trait::async_trait]
impl MigrationRunner<sqlx::MySql> for MySqlMigrationRunner {
/// Executes every statement of `sql` for `migration` on a single transaction.
///
/// The script is split with `parse_mysql_statements`. If that reports a parse
/// error, a warning is logged and the script is split naively on `";\n"` instead.
/// The statements are then executed in order and the transaction is committed
/// after the last one succeeds.
///
/// NOTE(review): MySQL implicitly commits on most DDL statements, so the
/// transaction presumably only gives rollback protection for DML. Confirm
/// against the MySQL manual before relying on atomicity.
///
/// # Errors
///
/// Returns an error if the transaction cannot be started, if any statement
/// fails, or if the commit fails. On an early return the transaction is dropped
/// without being committed.
async fn run_migration(&self, migration: &Migration, sql: &str) -> Result<()> {
    debug!("Executing MySQL migration: {}", migration.id);

    // Split the script before opening the transaction so that no pooled
    // connection is held while doing CPU-only parsing work.
    let statements = match parse_mysql_statements(sql) {
        Ok(stmts) => stmts,
        Err(e) => {
            warn!(
                "Failed to parse MySQL SQL with sqlparser, falling back to naive splitting: {}",
                e
            );
            sql.split(";\n")
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect()
        }
    };
    let total = statements.len();

    let mut tx = self.pool.begin().await?;
    for (i, statement) in statements.iter().enumerate() {
        // Borrow statements that are already terminated; allocate only when a
        // trailing `;` has to be appended (naive splitting strips it).
        let full_statement: std::borrow::Cow<'_, str> = if statement.ends_with(';') {
            std::borrow::Cow::Borrowed(statement.as_str())
        } else {
            std::borrow::Cow::Owned(format!("{};", statement))
        };
        debug!(
            "Executing statement {} of {} for migration {}",
            i + 1,
            total,
            migration.id
        );
        sqlx::query(&*full_statement).execute(&mut *tx).await?;
    }
    tx.commit().await?;

    info!(
        "Successfully executed MySQL migration: {} ({} statements)",
        migration.id, total
    );
    Ok(())
}
/// Reports whether the `hammerwork_migrations` table exists in the current database.
///
/// Counts matching rows in `information_schema.tables` for the schema returned
/// by `DATABASE()` and returns `true` when the count is greater than zero.
///
/// # Errors
///
/// Returns an error if the query fails or if the `count` column cannot be
/// decoded as an `i64`.
async fn migration_table_exists(&self) -> Result<bool> {
    let row = sqlx::query(
        "SELECT COUNT(*) as count FROM information_schema.tables\n\
         WHERE table_schema = DATABASE()\n\
         AND table_name = 'hammerwork_migrations'",
    )
    .fetch_one(&self.pool)
    .await?;

    // `try_get` surfaces a missing or undecodable column as an error, where
    // `get` would panic.
    let count: i64 = row.try_get("count")?;
    Ok(count > 0)
}
/// Creates the `hammerwork_migrations` tracking table if it does not exist yet.
///
/// The statement uses `CREATE TABLE IF NOT EXISTS`, so calling this when the
/// table is already present succeeds; the info message is logged either way.
///
/// # Errors
///
/// Returns an error if executing the statement fails.
async fn create_migration_table(&self) -> Result<()> {
    const CREATE_TABLE_SQL: &str = "\n\
        CREATE TABLE IF NOT EXISTS hammerwork_migrations (\n\
        migration_id VARCHAR(255) NOT NULL PRIMARY KEY,\n\
        executed_at TIMESTAMP(6) NOT NULL,\n\
        execution_time_ms BIGINT NOT NULL\n\
        )\n";

    sqlx::query(CREATE_TABLE_SQL).execute(&self.pool).await?;

    info!("Created MySQL migration tracking table");
    Ok(())
}
/// Loads every row of `hammerwork_migrations`, ordered by `executed_at`.
///
/// # Errors
///
/// Returns an error if the query fails or if any column of a row cannot be
/// decoded into the corresponding `MigrationRecord` field.
async fn get_executed_migrations(&self) -> Result<Vec<MigrationRecord>> {
    let rows = sqlx::query(
        "SELECT migration_id, executed_at, execution_time_ms\n\
         FROM hammerwork_migrations\n\
         ORDER BY executed_at",
    )
    .fetch_all(&self.pool)
    .await?;

    let mut records = Vec::with_capacity(rows.len());
    for row in &rows {
        // `try_get` reports decode failures as errors, where `get` would panic.
        records.push(MigrationRecord {
            migration_id: row.try_get("migration_id")?,
            executed_at: row.try_get("executed_at")?,
            // The column is a signed BIGINT; the cast mirrors the `u64 as i64`
            // conversion done when the record is inserted.
            execution_time_ms: row.try_get::<i64, _>("execution_time_ms")? as u64,
        });
    }
    Ok(records)
}
/// Inserts a row into `hammerwork_migrations` marking `migration` as executed.
///
/// The row stores the migration id, the current UTC time, and
/// `execution_time_ms` converted to a signed 64-bit integer.
///
/// # Errors
///
/// Returns an error if the insert fails.
async fn record_migration(&self, migration: &Migration, execution_time_ms: u64) -> Result<()> {
    const INSERT_SQL: &str =
        "INSERT INTO hammerwork_migrations (migration_id, executed_at, execution_time_ms)\n\
         VALUES (?, ?, ?)";

    // The column is a signed BIGINT, so the duration is bound as `i64`.
    let elapsed_ms = execution_time_ms as i64;

    sqlx::query(INSERT_SQL)
        .bind(&migration.id)
        .bind(Utc::now())
        .bind(elapsed_ms)
        .execute(&self.pool)
        .await?;

    debug!("Recorded MySQL migration: {}", migration.id);
    Ok(())
}
}
#[cfg(test)]
mod tests {
/// Checks that splitting a script on `";\n"`, trimming each piece and dropping
/// empty pieces yields one entry per statement, with leading comments kept
/// attached to the statement that follows them.
#[test]
fn test_sql_statement_splitting() {
    let script = "\n\
        -- Comment line\n\
        CREATE TABLE test_table (\n\
        id INTEGER PRIMARY KEY\n\
        );\n\
        -- Another comment\n\
        ALTER TABLE test_table ADD COLUMN name VARCHAR(50);\n\
        CREATE INDEX idx_test ON test_table (name);\n";

    let pieces: Vec<&str> = script
        .split(";\n")
        .map(str::trim)
        .filter(|piece| !piece.is_empty())
        .collect();

    // One expected keyword per statement, in script order.
    let expected = ["CREATE TABLE", "ALTER TABLE", "CREATE INDEX"];
    assert_eq!(pieces.len(), expected.len());
    for (piece, keyword) in pieces.iter().zip(expected.iter()) {
        assert!(piece.contains(*keyword));
    }
}
}