use chrono::{DateTime, Utc};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use std::fs;
use std::path::Path;
use super::definitions::{Migration, MigrationConfig};
use crate::error::{OrmError, OrmResult};
/// Creates, discovers and parses SQL migration files, and builds the SQL used
/// to record applied migrations in the tracking table.
///
/// The migrations directory and tracking-table name come from the
/// [`MigrationConfig`] the manager was constructed with.
pub struct MigrationManager {
    config: MigrationConfig,
}

impl MigrationManager {
    /// Creates a manager using `MigrationConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(MigrationConfig::default())
    }

    /// Creates a manager with an explicit configuration.
    pub fn with_config(config: MigrationConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this manager was built with.
    pub fn config(&self) -> &MigrationConfig {
        &self.config
    }

    /// Creates a new migration file from a template in the migrations directory,
    /// creating the directory if needed, and returns the file name (not the full path).
    ///
    /// The file is named `YYYYMMDD_HHMMSS_<slug>.sql` using the current UTC time.
    /// `<slug>` is `name` trimmed and lowercased. Every character other than an
    /// alphanumeric, `-` or `_` is replaced by `_`, so the name can never
    /// introduce a path separator.
    ///
    /// Note: performs blocking `std::fs` I/O despite being `async`.
    ///
    /// # Errors
    ///
    /// Returns [`OrmError::Migration`] in any of these cases:
    /// - `name` is empty or whitespace-only;
    /// - the directory cannot be created;
    /// - a file with the same name already exists;
    /// - the file cannot be written.
    pub async fn create_migration(&self, name: &str) -> OrmResult<String> {
        let slug = Self::sanitize_migration_name(name);
        if slug.is_empty() {
            return Err(OrmError::Migration(
                "Migration name must not be empty".to_string(),
            ));
        }

        fs::create_dir_all(&self.config.migrations_dir).map_err(|e| {
            OrmError::Migration(format!("Failed to create migrations directory: {}", e))
        })?;

        let timestamp = Utc::now().format("%Y%m%d_%H%M%S").to_string();
        let migration_id = format!("{}_{}", timestamp, slug);
        let filename = format!("{}.sql", migration_id);
        let filepath = self.config.migrations_dir.join(&filename);
        let template = self.create_migration_template(name, &migration_id);

        // `create_new` refuses to overwrite an existing migration, e.g. when the
        // same name is created twice within one second.
        let mut file = fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&filepath)
            .map_err(|e| OrmError::Migration(format!("Failed to write migration file: {}", e)))?;
        std::io::Write::write_all(&mut file, template.as_bytes())
            .map_err(|e| OrmError::Migration(format!("Failed to write migration file: {}", e)))?;

        Ok(filename)
    }

    /// Turns a free-form migration name into a filename-safe slug.
    ///
    /// Keeps alphanumerics, `-` and `_`. Everything else (spaces, `/`, `\`,
    /// `.`, ...) becomes `_`. The result is lowercased.
    fn sanitize_migration_name(name: &str) -> String {
        name.trim()
            .chars()
            .map(|c| {
                if c.is_alphanumeric() || c == '-' || c == '_' {
                    c
                } else {
                    '_'
                }
            })
            .collect::<String>()
            .to_lowercase()
    }

    /// Loads and parses every regular `*.sql` file in the migrations directory,
    /// sorted by migration id.
    ///
    /// Returns an empty list if the directory does not exist. For ids produced
    /// by [`Self::create_migration`], id order is chronological because the id
    /// starts with the timestamp.
    ///
    /// Note: performs blocking `std::fs` I/O despite being `async`.
    ///
    /// # Errors
    ///
    /// Returns [`OrmError::Migration`] in any of these cases:
    /// - the directory or one of its entries cannot be read;
    /// - any matching file fails to parse.
    pub async fn load_migrations(&self) -> OrmResult<Vec<Migration>> {
        let dir = &self.config.migrations_dir;
        if !dir.exists() {
            return Ok(Vec::new());
        }

        let entries = fs::read_dir(dir).map_err(|e| {
            OrmError::Migration(format!("Failed to read migrations directory: {}", e))
        })?;

        let mut migrations = Vec::new();
        for entry in entries {
            let entry = entry.map_err(|e| {
                OrmError::Migration(format!("Failed to read directory entry: {}", e))
            })?;
            let path = entry.path();
            // Skip directories that merely happen to end in `.sql`.
            if path.is_file() && path.extension().is_some_and(|ext| ext == "sql") {
                migrations.push(self.parse_migration_file(&path).await?);
            }
        }

        migrations.sort_by(|a, b| a.id.cmp(&b.id));
        Ok(migrations)
    }

    /// Reads one migration file and builds a [`Migration`] from its name and content.
    ///
    /// The id is the file stem. Two name layouts are recognised:
    /// - `YYYYMMDD_HHMMSS_<words>` (as written by [`Self::create_migration`]):
    ///   the name is `<words>` with `_` shown as spaces, and `created_at`
    ///   comes from the date and time parts.
    /// - `<prefix>_<words>`: the name is `<words>`, and `created_at` comes from
    ///   `<prefix>` read as `YYYYMMDD` (midnight UTC).
    ///
    /// If the timestamp cannot be parsed, `created_at` falls back to the current time.
    ///
    /// # Errors
    ///
    /// Returns [`OrmError::Migration`] in any of these cases:
    /// - the file cannot be read;
    /// - its stem is not valid UTF-8;
    /// - the stem has no `_` separator.
    async fn parse_migration_file(&self, path: &Path) -> OrmResult<Migration> {
        let content = fs::read_to_string(path)
            .map_err(|e| OrmError::Migration(format!("Failed to read migration file: {}", e)))?;

        let filename = path
            .file_stem()
            .and_then(|s| s.to_str())
            .ok_or_else(|| OrmError::Migration("Invalid migration filename".to_string()))?;

        let parts: Vec<&str> = filename.split('_').collect();
        if parts.len() < 2 {
            return Err(OrmError::Migration(
                "Migration filename must follow format: timestamp_name".to_string(),
            ));
        }

        // `YYYYMMDD_HHMMSS_name` carries a time part; anything else is `prefix_name`.
        let has_time_part = parts.len() >= 3 && parts[0].len() == 8 && parts[1].len() == 6;
        let (name_parts, time_part) = if has_time_part {
            (&parts[2..], Some(parts[1]))
        } else {
            (&parts[1..], None)
        };
        let name = name_parts.join(" ");

        let (up_sql, down_sql) = self.parse_migration_content(&content)?;

        let created_at = self
            .parse_migration_timestamp(parts[0], time_part)
            .unwrap_or_else(|e| {
                tracing::debug!(
                    "Could not parse timestamp of migration '{}' ({}); using current time",
                    filename,
                    e
                );
                Utc::now()
            });

        Ok(Migration {
            id: filename.to_string(),
            name,
            up_sql,
            down_sql,
            created_at,
        })
    }

    /// Splits migration file content into its `(up_sql, down_sql)` bodies.
    ///
    /// Sections start at a `--` comment whose first word is `up` or `down`
    /// (case-insensitive, e.g. `-- Up migration`, `-- DOWN`), or whose text
    /// contains `up migration` / `down migration`. Only comment lines can be
    /// markers. Whole-word matching means `-- update ...` and `-- download ...`
    /// do not switch sections.
    ///
    /// Lines before the first marker, blank lines and all other `--` comment
    /// lines are dropped. Each body is joined with `\n` and trimmed.
    ///
    /// This never returns `Err`; the `Result` is kept for interface stability.
    fn parse_migration_content(&self, content: &str) -> OrmResult<(String, String)> {
        #[derive(Clone, Copy)]
        enum Section {
            Preamble,
            Up,
            Down,
        }

        /// Classifies the text after a leading `--`; `None` means an ordinary comment.
        fn marker(comment: &str) -> Option<Section> {
            let comment = comment.to_lowercase();
            let first_word = comment
                .split(|c: char| !c.is_alphanumeric())
                .find(|w| !w.is_empty());
            // The first-word check runs first so `-- down migration (reverses up
            // migration)` is classified by its leading word.
            match first_word {
                Some("up") => Some(Section::Up),
                Some("down") => Some(Section::Down),
                _ if comment.contains("up migration") => Some(Section::Up),
                _ if comment.contains("down migration") => Some(Section::Down),
                _ => None,
            }
        }

        let mut section = Section::Preamble;
        let mut up_sql = Vec::new();
        let mut down_sql = Vec::new();

        for line in content.lines() {
            let trimmed = line.trim();
            if trimmed.is_empty() {
                continue;
            }
            if let Some(comment) = trimmed.strip_prefix("--") {
                if let Some(next) = marker(comment) {
                    section = next;
                }
                continue;
            }
            match section {
                Section::Up => up_sql.push(line),
                Section::Down => down_sql.push(line),
                Section::Preamble => {}
            }
        }

        Ok((
            up_sql.join("\n").trim().to_string(),
            down_sql.join("\n").trim().to_string(),
        ))
    }

    /// Parses the timestamp encoded in a migration file name as UTC.
    ///
    /// Only the first 8 bytes of `date_str` are used, read as `YYYYMMDD`.
    /// `time_str`, when present, is read as `HHMMSS`; otherwise the time is midnight.
    fn parse_migration_timestamp(
        &self,
        date_str: &str,
        time_str: Option<&str>,
    ) -> Result<DateTime<Utc>, chrono::ParseError> {
        // `get` instead of `[..8]`: a shorter prefix (e.g. `001_init.sql`) or a
        // multi-byte char straddling byte 8 must produce a parse error, not a panic.
        let date = date_str.get(..8).unwrap_or(date_str);
        let time = time_str.unwrap_or("000000");
        let naive =
            chrono::NaiveDateTime::parse_from_str(&format!("{}{}", date, time), "%Y%m%d%H%M%S")?;
        Ok(DateTime::from_naive_utc_and_offset(naive, Utc))
    }

    /// Renders the initial content of a new migration file.
    ///
    /// The content has a header (name, id, creation time) followed by empty
    /// `-- Up migration` and `-- Down migration` sections.
    fn create_migration_template(&self, name: &str, migration_id: &str) -> String {
        // Collapse all whitespace (including newlines) so the name cannot spill
        // out of the `-- Migration:` comment line into executable SQL.
        let display_name = name.split_whitespace().collect::<Vec<_>>().join(" ");
        format!(
            "-- Migration: {}\n\
             -- ID: {}\n\
             -- Created: {}\n\n\
             -- Up migration\n\
             -- Add your schema changes here\n\n\n\
             -- Down migration \n\
             -- Add rollback statements here\n\n",
            display_name,
            migration_id,
            Utc::now().format("%Y-%m-%d %H:%M:%S UTC")
        )
    }

    /// Splits `sql` into individual statements, each terminated with `;`.
    ///
    /// Statements are parsed with sqlparser's `GenericDialect` and re-rendered
    /// through their `Display` impl, so whitespace and formatting may differ
    /// from the input.
    ///
    /// If parsing fails, a warning is logged and the input is split naively on
    /// `;`. That fallback mis-splits semicolons inside string literals or
    /// procedural bodies.
    ///
    /// This never returns `Err`; the `Result` is kept for interface stability.
    pub fn split_sql_statements(&self, sql: &str) -> OrmResult<Vec<String>> {
        let dialect = GenericDialect {};
        match Parser::parse_sql(&dialect, sql) {
            Ok(parsed_statements) => Ok(parsed_statements
                .iter()
                .map(|stmt| format!("{};", stmt))
                .collect()),
            Err(e) => {
                tracing::warn!("SQL parsing failed, using naive semicolon splitting: {}", e);
                Ok(sql
                    .split(';')
                    .map(str::trim)
                    .filter(|s| !s.is_empty())
                    .map(|s| format!("{};", s))
                    .collect())
            }
        }
    }

    /// DDL that creates the migrations tracking table if it does not exist.
    ///
    /// The table has columns `id`, `applied_at` and `batch`. The table name is
    /// interpolated unquoted, so it must be a trusted identifier.
    pub fn create_migrations_table_sql(&self) -> String {
        format!(
            "CREATE TABLE IF NOT EXISTS {} (\n    \
             id VARCHAR(255) PRIMARY KEY,\n    \
             applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,\n    \
             batch INTEGER NOT NULL\n\
             );",
            self.config.migrations_table
        )
    }

    /// Query plus parameters that select the tracking row for `migration_id`.
    ///
    /// Uses the `$1` placeholder style.
    pub fn check_migration_sql(&self, migration_id: &str) -> (String, Vec<String>) {
        (
            format!(
                "SELECT id FROM {} WHERE id = $1",
                self.config.migrations_table
            ),
            vec![migration_id.to_string()],
        )
    }

    /// Insert statement plus parameters that record `migration_id` as applied in `batch`.
    ///
    /// `applied_at` is the current UTC time as an RFC 3339 string, cast with
    /// `::timestamp`. `batch` is passed as a string and cast with `::integer`.
    pub fn record_migration_sql(&self, migration_id: &str, batch: i32) -> (String, Vec<String>) {
        (
            format!(
                "INSERT INTO {} (id, applied_at, batch) VALUES ($1, $2::timestamp, $3::integer)",
                self.config.migrations_table
            ),
            vec![
                migration_id.to_string(),
                Utc::now().to_rfc3339(),
                batch.to_string(),
            ],
        )
    }

    /// Delete statement plus parameters that remove the tracking row for `migration_id`.
    pub fn remove_migration_sql(&self, migration_id: &str) -> (String, Vec<String>) {
        (
            format!("DELETE FROM {} WHERE id = $1", self.config.migrations_table),
            vec![migration_id.to_string()],
        )
    }

    /// Query for the highest recorded batch number, or `0` if the table is empty.
    pub fn get_latest_batch_sql(&self) -> String {
        format!(
            "SELECT COALESCE(MAX(batch), 0) FROM {}",
            self.config.migrations_table
        )
    }

    /// Query for all applied migrations.
    ///
    /// Results are ordered newest batch first, then newest `applied_at` first
    /// within a batch.
    pub fn get_applied_migrations_sql(&self) -> String {
        format!(
            "SELECT id, applied_at, batch FROM {} ORDER BY batch DESC, applied_at DESC",
            self.config.migrations_table
        )
    }
}
impl Default for MigrationManager {
fn default() -> Self {
Self::new()
}
}