use crate::{Db, DbResult, DbError};
use alun_config::MigrationConfig;
use std::path::{Path, PathBuf};
use tracing::{info, warn};
/// Outcome of applying (or rolling back) a single migration.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare records directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationRecord {
    /// Version prefix taken from the migration file name (text before the first `_`).
    pub version: String,
    /// File name of the migration that was applied; may be empty (e.g. for rollbacks).
    pub name: String,
    /// Execution timestamp, if known. `Migrator` currently leaves this as `None`.
    pub executed_at: Option<String>,
    /// Whether the migration step completed successfully.
    pub success: bool,
}
/// Applies and reverts SQL migrations stored as files in a configured directory.
///
/// Pairs the database handle that runs the SQL with the settings that decide
/// whether migrations run at all and where the migration files are located.
pub struct Migrator {
    /// Database handle used for both migration SQL and `_migrations` bookkeeping queries.
    db: Db,
    /// Migration settings (read for the `enabled` flag and the directory `path`).
    config: MigrationConfig,
}
impl Migrator {
    /// Creates a migrator that applies the SQL files under `config.path` to `db`.
    pub fn new(db: Db, config: MigrationConfig) -> Self {
        Self { db, config }
    }

    /// Applies every pending `*.up.sql` migration in ascending file-name order.
    ///
    /// Returns an empty list without touching the database when `config.enabled` is false.
    /// Otherwise it ensures the `_migrations` tracking table exists and discovers the
    /// migration files. It skips versions already recorded with `success = TRUE` and
    /// executes the remaining ones one at a time.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while creating the tracking table, reading the applied
    /// versions, or executing a migration. Migrations that ran before the failure are not
    /// undone by this method.
    pub async fn run(&self) -> DbResult<Vec<MigrationRecord>> {
        if !self.config.enabled {
            return Ok(vec![]);
        }
        self.ensure_migrations_table().await?;
        let files = self.discover_migration_files().await;
        if files.is_empty() {
            info!("没有发现待执行的迁移文件");
            return Ok(vec![]);
        }
        // A set gives O(1) membership checks instead of a linear scan per file.
        let executed: std::collections::HashSet<String> =
            self.get_executed_migrations().await?.into_iter().collect();
        let mut results = Vec::new();
        for file in &files {
            let version = extract_version(file);
            if executed.contains(&version) {
                continue;
            }
            match self.execute_migration(file, &version).await {
                Ok(record) => results.push(record),
                Err(e) => {
                    warn!("迁移 {} 执行失败: {}", version, e);
                    return Err(e);
                }
            }
        }
        info!("迁移完成,共执行 {} 个迁移", results.len());
        Ok(results)
    }

    /// Reverts the most recently applied migration (the highest recorded version).
    ///
    /// Finds a `<version>[_<name>].down.sql` file whose version matches exactly and runs it.
    /// It then deletes that version's row from `_migrations`. Returns `Ok(None)` when no
    /// migration has been applied or when no matching down file exists.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases: the applied versions cannot be queried, the
    /// down file cannot be read, or executing the down SQL or updating the tracking table fails.
    pub async fn rollback(&self) -> DbResult<Option<MigrationRecord>> {
        let mut executed = self.get_executed_migrations().await?;
        // The query orders by `version`, so the last element is the newest migration.
        let last = match executed.pop() {
            Some(version) => version,
            None => {
                info!("没有可回滚的迁移");
                return Ok(None);
            }
        };
        match self.find_down_file(&last).await {
            Some(path) => {
                info!("回滚迁移: {}", last);
                let sql = tokio::fs::read_to_string(&path)
                    .await
                    .map_err(|e| DbError::Other(format!("读取迁移文件失败: {}", e)))?;
                self.db.execute(&sql, &[]).await?;
                self.mark_migration_rolled_back(&last).await?;
                Ok(Some(MigrationRecord {
                    version: last,
                    name: String::new(),
                    executed_at: None,
                    success: true,
                }))
            }
            None => {
                warn!("迁移 {} 没有对应的 down 文件", last);
                Ok(None)
            }
        }
    }

    /// Creates the `_migrations` tracking table if it does not already exist.
    async fn ensure_migrations_table(&self) -> DbResult<()> {
        let sql = r#"
CREATE TABLE IF NOT EXISTS _migrations (
version VARCHAR(255) PRIMARY KEY,
name VARCHAR(512) NOT NULL DEFAULT '',
executed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
success BOOLEAN DEFAULT TRUE
)
"#;
        self.db.execute(sql, &[]).await?;
        info!("迁移追踪表 _migrations 已就绪");
        Ok(())
    }

    /// Lists every entry of the migration directory without blocking the async runtime.
    ///
    /// A missing directory yields an empty list silently, matching the previous
    /// `exists()` check. Any other I/O error is logged and also yields whatever was
    /// collected so far.
    async fn list_migration_dir(&self) -> Vec<PathBuf> {
        let dir = Path::new(&self.config.path);
        let mut entries = match tokio::fs::read_dir(dir).await {
            Ok(entries) => entries,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Vec::new(),
            Err(e) => {
                warn!("无法读取迁移目录 {}: {}", dir.display(), e);
                return Vec::new();
            }
        };
        let mut paths = Vec::new();
        loop {
            match entries.next_entry().await {
                Ok(Some(entry)) => paths.push(entry.path()),
                Ok(None) => break,
                Err(e) => {
                    warn!("读取迁移目录条目失败: {}", e);
                    break;
                }
            }
        }
        paths
    }

    /// Returns the `*.up.sql` files in the migration directory, sorted by path.
    ///
    /// Version prefixes are expected to sort lexicographically (e.g. timestamps), so
    /// the path order is also the order in which migrations are applied.
    async fn discover_migration_files(&self) -> Vec<PathBuf> {
        let mut files: Vec<PathBuf> = self
            .list_migration_dir()
            .await
            .into_iter()
            .filter(|path| {
                path.file_name()
                    .and_then(|n| n.to_str())
                    .map_or(false, |n| n.ends_with(".up.sql"))
            })
            .collect();
        files.sort();
        files
    }

    /// Returns the versions recorded with `success = TRUE`, ordered by `version`.
    ///
    /// Rows whose `version` column cannot be read as a `String` are skipped.
    async fn get_executed_migrations(&self) -> DbResult<Vec<String>> {
        let rows = self
            .db
            .query(
                "SELECT version FROM _migrations WHERE success = TRUE ORDER BY version",
                &[],
            )
            .await?;
        Ok(rows
            .iter()
            .filter_map(|r| r.get_as::<String>("version"))
            .collect())
    }

    /// Runs one up-migration file and records `version` in `_migrations`.
    ///
    /// NOTE(review): the migration SQL and the tracking INSERT are two separate
    /// `execute` calls. Unless `Db` wraps them in a transaction (not visible here), a
    /// failing INSERT leaves the schema change applied but unrecorded.
    async fn execute_migration(&self, file: &Path, version: &str) -> DbResult<MigrationRecord> {
        let name = file
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("")
            .to_string();
        info!("执行迁移: {}", name);
        let sql = tokio::fs::read_to_string(file)
            .await
            .map_err(|e| DbError::Other(format!("读取迁移文件失败: {}", e)))?;
        self.db.execute(&sql, &[]).await?;
        self.db
            .execute(
                "INSERT INTO _migrations (version, name, success) VALUES ($1, $2, TRUE)",
                &[version, &name],
            )
            .await?;
        Ok(MigrationRecord {
            version: version.to_string(),
            name,
            executed_at: None,
            success: true,
        })
    }

    /// Removes `version` from the tracking table. The row is deleted, not flagged.
    async fn mark_migration_rolled_back(&self, version: &str) -> DbResult<()> {
        self.db
            .execute("DELETE FROM _migrations WHERE version = $1", &[version])
            .await?;
        Ok(())
    }

    /// Finds the `.down.sql` file whose version exactly equals `version`.
    ///
    /// Exact matching prevents version `1` from picking up `10_x.down.sql`, which the
    /// old prefix check allowed. If several down files share the version, the smallest
    /// path is returned so the result does not depend on directory iteration order.
    async fn find_down_file(&self, version: &str) -> Option<PathBuf> {
        self.list_migration_dir()
            .await
            .into_iter()
            .filter(|path| {
                path.file_name()
                    .and_then(|n| n.to_str())
                    .and_then(Self::down_file_version)
                    == Some(version)
            })
            .min()
    }

    /// Extracts the version from a `*.down.sql` file name, or `None` for other files.
    ///
    /// Uses the same rule as the up files: text before the first `_`, or the whole stem
    /// when there is no underscore (`001_drop.down.sql` and `001.down.sql` both give `001`).
    fn down_file_version(file_name: &str) -> Option<&str> {
        let stem = file_name.strip_suffix(".down.sql")?;
        Some(stem.split_once('_').map_or(stem, |(version, _)| version))
    }
}
/// Extracts the version prefix from a migration file path.
///
/// Only the file name is inspected. The version is everything before the first `_`
/// (`20240501120000_create_users.up.sql` -> `20240501120000`). When the name contains no
/// `_`, the trailing `.up.sql` or `.down.sql` suffix is stripped instead
/// (`001.up.sql` -> `001`). Only the suffix is removed; the old `replace` also removed
/// occurrences in the middle of the name. Returns an empty string when the path has no
/// UTF-8 file name.
fn extract_version(path: &Path) -> String {
    path.file_name()
        .and_then(|n| n.to_str())
        .map(|name| match name.split_once('_') {
            Some((version, _)) => version.to_string(),
            None => name
                .strip_suffix(".up.sql")
                .or_else(|| name.strip_suffix(".down.sql"))
                .unwrap_or(name)
                .to_string(),
        })
        .unwrap_or_default()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The version is the file-name prefix before the first underscore.
    #[test]
    fn test_extract_version() {
        let path = Path::new("20240501120000_create_users_table.up.sql");
        assert_eq!(extract_version(path), "20240501120000");
    }

    /// Without an underscore, the `.up.sql` suffix is stripped to obtain the version.
    #[test]
    fn test_extract_version_without_name() {
        assert_eq!(extract_version(Path::new("001.up.sql")), "001");
    }

    /// Only the file name is inspected; underscores in parent directories are ignored.
    #[test]
    fn test_extract_version_ignores_directories() {
        assert_eq!(extract_version(Path::new("db_migrations/002_add_index.up.sql")), "002");
        assert_eq!(extract_version(Path::new("db_migrations/003.up.sql")), "003");
    }
}