use clap::{Parser, Subcommand};
use dbnexus::foundation::DatabaseType as MigrationDatabaseType;
use dbnexus::{DbError, DbPool, DbResult};
use dbnexus::{MigrationExecutor, MigrationFile, MigrationFileParser};
#[cfg(feature = "sql-parser")]
use dbnexus::{SqlOperationType, SqlParser};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
// Top-level command-line arguments for `dbnexus-migrate`.
//
// Plain `//` comments are used here on purpose: clap's derive turns `///`
// doc comments into `--help` text, and these notes are for maintainers only.
#[derive(Parser)]
#[command(name = "dbnexus-migrate")]
#[command(about = "DBNexus 数据库迁移工具", long_about = None)]
struct Cli {
    // Database connection URL; may also come from the `DATABASE_URL` env var.
    // It is a required argument for every subcommand, including `create` and
    // `generate`, whose handlers never receive it.
    #[arg(short, long, env = "DATABASE_URL")]
    database_url: String,
    // Optional configuration file path.
    // NOTE(review): not read anywhere in this file — confirm it is still needed.
    #[arg(short, long)]
    config: Option<PathBuf>,
    // Directory holding migration files; `main` creates it if it is missing.
    #[arg(short, long, default_value = "./migrations")]
    migrations_dir: PathBuf,
    // The subcommand to run.
    #[command(subcommand)]
    command: Commands,
}
// Subcommands accepted by `dbnexus-migrate`; `main` dispatches each variant
// to the matching handler function in this file.
//
// Plain `//` comments are used on purpose: clap's derive would turn `///`
// doc comments into `--help` text.
#[derive(Subcommand)]
enum Commands {
    // Write a new timestamped migration template (handled by `create_migration`).
    // NOTE(review): `directory` defaults to ./migrations independently of the
    // top-level `--migrations-dir` value.
    Create {
        description: String,
        #[arg(short, long, default_value = "./migrations")]
        directory: PathBuf,
    },
    // Apply pending migrations (handled by `run_migrations_up`); `--version`
    // is an inclusive upper bound on the versions applied.
    Up {
        #[arg(long)]
        version: Option<u32>,
    },
    // Roll back applied migrations (handled by `run_migrations_down`): the
    // latest one by default, every version >= `--version`, or all of them
    // with `--all` (which takes precedence over `--version`).
    Down {
        #[arg(long)]
        version: Option<u32>,
        #[arg(long, default_value = "false")]
        all: bool,
    },
    // Report applied vs. pending migrations (handled by `show_status`).
    Status,
    // Check connectivity and print pool statistics (handled by `test_connection`).
    TestConnection,
    // Produce a migration file (handled by `generate_migration`); a schema diff
    // is only attempted when both `--from-schema` and `--to-schema` are given.
    Generate {
        #[arg(long)]
        from_schema: Option<PathBuf>,
        #[arg(long)]
        to_schema: Option<PathBuf>,
        #[arg(short, long, default_value = "./migrations/generated.sql")]
        output: PathBuf,
        #[arg(short, long, default_value = "auto_generated")]
        description: String,
    },
    // Print the migration files found in the migrations directory
    // (handled by `list_migrations`).
    List,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
if !cli.migrations_dir.exists() {
fs::create_dir_all(&cli.migrations_dir).map_err(|e| DbError::Config(format!("无法创建迁移目录: {}", e)))?;
}
match &cli.command {
Commands::Create { description, directory } => {
create_migration(description, directory).await?;
}
Commands::Up { version } => {
run_migrations_up(&cli.database_url, &cli.migrations_dir, *version).await?;
}
Commands::Down { version, all } => {
run_migrations_down(&cli.database_url, *version, *all).await?;
}
Commands::Status => {
show_status(&cli.database_url, &cli.migrations_dir).await?;
}
Commands::TestConnection => {
test_connection(&cli.database_url).await?;
}
Commands::Generate {
from_schema,
to_schema,
output,
description,
} => {
generate_migration(from_schema, to_schema, output, description).await?;
}
Commands::List => {
list_migrations(&cli.database_url, &cli.migrations_dir).await?;
}
}
Ok(())
}
/// Create a new, empty migration file named `{unix_seconds}_{description}.sql`.
///
/// * `description` – free-form text; only alphanumerics, `_` and `-` are kept
///   for the file name, while the header comment keeps the full text (with
///   line breaks turned into spaces).
/// * `directory` – target directory; created if it does not exist.
///
/// # Errors
/// `DbError::Config` when the sanitized description is empty or longer than
/// 100 bytes, when the directory cannot be created, when the system clock is
/// before the Unix epoch, or when the file cannot be written. An existing file
/// with the same name is never overwritten.
async fn create_migration(description: &str, directory: &Path) -> DbResult<()> {
    use std::io::Write;

    // Validate before touching the filesystem so bad input leaves nothing behind.
    let sanitized_description: String = description
        .chars()
        .filter(|c| c.is_alphanumeric() || *c == '_' || *c == '-')
        .collect();
    if sanitized_description.is_empty() {
        return Err(DbError::Config("迁移描述不能只包含特殊字符".to_string()));
    }
    // `len()` counts UTF-8 bytes, not characters; a byte limit keeps the file
    // name safely below the 255-byte name limit of common filesystems even for
    // multi-byte (e.g. CJK) descriptions.
    if sanitized_description.len() > 100 {
        return Err(DbError::Config("迁移描述过长(最大 100 字符)".to_string()));
    }

    fs::create_dir_all(directory).map_err(|e| DbError::Config(format!("无法创建目录: {}", e)))?;

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| DbError::Config(format!("无法解析时间戳: {}", e)))?
        .as_secs();
    let filename = format!("{}_{}.sql", timestamp, sanitized_description);
    let filepath = directory.join(&filename);

    // The description is echoed into a `--` comment line; a raw line break would
    // end that comment and turn the rest of the description into live SQL.
    let header_description = description.replace(|c: char| c == '\r' || c == '\n', " ");
    let migration_content = format!(
        r#"-- Migration: {description}
-- Version: {timestamp}
-- Created: {created_at}
-- UP: Apply migration
-- Your migration SQL goes here
-- DOWN: Rollback migration
-- Reversal of migration SQL goes here
"#,
        description = header_description,
        timestamp = timestamp,
        created_at = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S")
    );

    // `create_new` refuses to clobber a migration created in the same second
    // with the same description (possibly already edited by the user).
    let mut file = fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&filepath)
        .map_err(|e| DbError::Config(format!("无法写入迁移文件: {}", e)))?;
    file.write_all(migration_content.as_bytes())
        .map_err(|e| DbError::Config(format!("无法写入迁移文件: {}", e)))?;

    println!("✓ 迁移文件已创建: {}", filepath.display());
    Ok(())
}
/// Print a report of applied migrations (from the history table) and of local
/// migration files that have not been applied yet.
///
/// Pool creation, session acquisition and history loading failures are only
/// printed, and the function then returns `Ok(())`. Errors from database-type
/// detection, executor creation and directory scanning are propagated.
async fn show_status(database_url: &str, migrations_dir: &Path) -> DbResult<()> {
    println!("\n╔══════════════════════════════════════════════════════════════╗");
    println!("║ 迁移状态查看 ║");
    println!("╚══════════════════════════════════════════════════════════════╝");

    let pool = match DbPool::new(database_url).await {
        Err(err) => {
            println!("\n❌ 数据库连接失败: {}", err);
            return Ok(());
        }
        Ok(connected) => connected,
    };

    let db_type = detect_database_type(database_url)
        .map_err(|err| DbError::Config(format!("数据库类型检测失败: {}", err)))?;
    println!("\n📊 数据库类型: {}", db_type);
    println!("📁 迁移目录: {}", migrations_dir.display());

    let session = match pool.get_session("admin").await {
        Err(err) => {
            println!("\n❌ 无法获取数据库会话: {}", err);
            return Ok(());
        }
        Ok(opened) => opened,
    };

    let mut executor = session.create_migration_executor(db_type)?;
    if let Err(err) = executor.load_history().await {
        println!("\n⚠️ 无法加载迁移历史: {}", err);
        println!(" 迁移历史表可能不存在");
        return Ok(());
    }

    // Applied-migration section; the history borrow is scoped to this block.
    {
        let history = executor.history();
        let applied = &history.applied_migrations;
        println!("\n✅ 已应用的迁移: {} 个", applied.len());

        if !applied.is_empty() {
            let newest = history
                .get_latest_version()
                .and_then(|latest_version| applied.iter().find(|m| m.version == latest_version));
            if let Some(record) = newest {
                println!(" 最新迁移:");
                println!(" - 版本: {}", record.version);
                println!(" - 描述: {}", record.description);
                println!(" - 应用时间: {}", record.applied_at);
            }

            println!("\n 迁移历史详情:");
            for (position, record) in applied.iter().enumerate() {
                println!(" [{:2}] v{:6} - {}", position + 1, record.version, record.description);
            }
        }
    }

    let on_disk = executor.scan_migrations(migrations_dir)?;
    let waiting = on_disk
        .iter()
        .filter(|file| !executor.history().is_version_applied(file.version()))
        .count();
    println!("\n📦 本地迁移文件: {} 个", on_disk.len());
    println!("⏳ 待应用的迁移: {} 个", waiting);

    if !on_disk.is_empty() {
        let already_applied: std::collections::HashSet<u32> = executor
            .history()
            .applied_migrations
            .iter()
            .map(|m| m.version)
            .collect();
        let outstanding: Vec<_> = on_disk
            .iter()
            .filter(|file| !already_applied.contains(&file.version()))
            .collect();

        if outstanding.is_empty() {
            println!("\n ✓ 所有迁移都已应用");
        } else {
            println!("\n 待应用迁移列表:");
            for (position, file) in outstanding.iter().enumerate() {
                println!(" [{:2}] v{:6} - {}", position + 1, file.version(), file.description());
            }
        }
    }

    println!("\n🔗 数据库连接: 已连接");
    println!(" URL: {}", mask_database_url(database_url));
    println!("\n{}", "─".repeat(60));
    Ok(())
}
/// Open a pool for `database_url`, verify a session can be obtained, and print
/// the detected database type, connection latency and pool statistics.
///
/// # Errors
/// Propagates the pool-creation error, the session-acquisition error and any
/// error from `session.connection()`. Returns `DbError::Config` when the URL
/// scheme cannot be mapped to a database type. Each failure makes the CLI exit
/// non-zero, so scripts can rely on the exit status.
async fn test_connection(database_url: &str) -> DbResult<()> {
    println!("\n╔══════════════════════════════════════════════════════════════╗");
    println!("║ 数据库连接测试 ║");
    println!("╚══════════════════════════════════════════════════════════════╝");
    println!("\n🔄 正在测试数据库连接...");

    // Measures pool creation only.
    let start_time = std::time::Instant::now();
    let pool = match DbPool::new(database_url).await {
        Ok(pool) => pool,
        Err(e) => {
            println!("\n❌ 连接失败: {}", e);
            return Err(e);
        }
    };
    let elapsed = start_time.elapsed();

    // A failed verification is an error, not a success: previously this path
    // printed the failure but still returned Ok(()) (exit status 0).
    let session = pool.get_session("admin").await.map_err(|e| {
        println!("\n❌ 连接验证失败: {}", e);
        e
    })?;
    // Only checks that the session exposes a connection; the value is not needed.
    let _ = session.connection()?;
    // Release the session before reading pool statistics below.
    drop(session);

    // Same error variant as the other commands use for type detection.
    let db_type = detect_database_type(database_url)
        .map_err(|e| DbError::Config(format!("数据库类型检测失败: {}", e)))?;

    println!("\n✅ 连接成功!");
    println!("\n 数据库类型: {}", db_type);
    println!(" 连接耗时: {:?}", elapsed);
    println!(" 连接URL: {}", mask_database_url(database_url));
    println!("\n 连接池状态:");
    let status = pool.status();
    println!(" - 总连接数: {}", status.total);
    println!(" - 活跃连接: {}", status.active);
    println!(" - 空闲连接: {}", status.idle);

    println!("\n{}", "─".repeat(60));
    Ok(())
}
/// Apply every local migration that is not yet recorded in the history table,
/// in ascending version order.
///
/// * `target_version` – when set, only versions `<= target_version` are applied.
///
/// Stops at the first failing migration and returns its error; migrations
/// applied before it stay applied.
async fn run_migrations_up(database_url: &str, migrations_dir: &Path, target_version: Option<u32>) -> DbResult<()> {
    use std::io::Write;

    println!("\n╔══════════════════════════════════════════════════════════════╗");
    println!("║ 应用迁移 ║");
    println!("╚══════════════════════════════════════════════════════════════╝");

    let pool = DbPool::new(database_url).await?;
    let db_type = detect_database_type(database_url)?;
    println!("\n📊 数据库类型: {}", db_type);
    println!("📁 迁移目录: {}", migrations_dir.display());

    let session = pool.get_session("admin").await?;
    let mut executor = session.create_migration_executor(db_type)?;
    let migrations = executor.scan_migrations(migrations_dir)?;
    if migrations.is_empty() {
        println!("\n⚠️ 迁移目录中没有找到迁移文件");
        return Ok(());
    }

    executor.load_history().await?;
    let applied_versions: std::collections::HashSet<u32> = executor
        .history()
        .applied_migrations
        .iter()
        .map(|m| m.version)
        .collect();

    let mut to_apply: Vec<_> = migrations
        .iter()
        .filter(|m| !applied_versions.contains(&m.version()))
        .filter(|m| target_version.map_or(true, |target| m.version() <= target))
        .collect();
    to_apply.sort_by_key(|m| m.version());

    if to_apply.is_empty() {
        println!("\n✓ 没有待应用的迁移");
        return Ok(());
    }

    println!("\n📦 找到 {} 个待应用迁移", to_apply.len());
    if let Some(target) = target_version {
        println!(" 目标版本: {}", target);
    }

    println!("\n🚀 开始应用迁移...");
    let mut success_count = 0;
    for migration in &to_apply {
        print!(
            " 正在应用 v{} - {} ... ",
            migration.version(),
            migration.description()
        );
        // stdout is line-buffered: without a flush the progress text above only
        // shows up after the (possibly slow) migration finishes. A failed flush
        // affects display only, so it is deliberately ignored.
        let _ = std::io::stdout().flush();

        match executor.apply_migration_file_public(migration).await {
            Ok(_) => {
                println!("✓");
                success_count += 1;
            }
            Err(e) => {
                println!("❌ 失败: {}", e);
                return Err(e);
            }
        }
    }

    println!("\n✅ 成功应用 {} / {} 个迁移", success_count, to_apply.len());
    println!("\n{}", "─".repeat(60));
    Ok(())
}
/// Roll back applied migrations, newest version first, via `rollback_migration`.
///
/// Selection, in priority order:
/// * `rollback_all` – every applied migration;
/// * `target_version = Some(n)` – every applied version `>= n`
///   (NOTE(review): version `n` itself is rolled back too — confirm that is the
///   intended meaning of "回滚到版本 n");
/// * otherwise – only the highest applied version.
///
/// # Errors
/// Connection, detection and history-loading errors are propagated. A failing
/// rollback stops the loop and is reported as `DbError::Migration`.
async fn run_migrations_down(database_url: &str, target_version: Option<u32>, rollback_all: bool) -> DbResult<()> {
    use std::io::Write;

    println!("\n╔══════════════════════════════════════════════════════════════╗");
    println!("║ 回滚迁移 ║");
    println!("╚══════════════════════════════════════════════════════════════╝");

    let pool = DbPool::new(database_url).await?;
    let db_type = detect_database_type(database_url)?;
    println!("\n📊 数据库类型: {}", db_type);

    let session = pool.get_session("admin").await?;
    let mut executor = session.create_migration_executor(db_type)?;
    executor.load_history().await?;

    let applied_migrations = &executor.history().applied_migrations;
    if applied_migrations.is_empty() {
        println!("\n⚠️ 没有已应用的迁移可以回滚");
        return Ok(());
    }

    let mut versions_to_rollback: Vec<u32> = if rollback_all {
        applied_migrations.iter().map(|m| m.version).collect()
    } else if let Some(target) = target_version {
        applied_migrations
            .iter()
            .map(|m| m.version)
            .filter(|v| *v >= target)
            .collect()
    } else {
        // The history is non-empty here, so `max` always yields one version.
        applied_migrations.iter().map(|m| m.version).max().into_iter().collect()
    };
    // Highest version first.
    versions_to_rollback.sort_unstable_by(|a, b| b.cmp(a));

    println!("\n📦 需要回滚 {} 个迁移", versions_to_rollback.len());
    if rollback_all {
        println!(" 模式: 回滚所有迁移");
    } else if let Some(target) = target_version {
        println!(" 模式: 回滚到版本 {}", target);
    } else {
        println!(" 模式: 回滚上一个版本");
    }

    println!("\n🔄 开始回滚迁移...");

    // Copy out (version, description) pairs so the history borrow ends before
    // `executor` is borrowed mutably in the loop.
    let rollback_info: Vec<(u32, String)> = versions_to_rollback
        .iter()
        .filter_map(|version| {
            applied_migrations
                .iter()
                .find(|m| m.version == *version)
                .map(|info| (info.version, info.description.clone()))
        })
        .collect();

    let mut success_count = 0;
    for (version, description) in &rollback_info {
        print!(" 正在回滚 v{} - {} ... ", version, description);
        // stdout is line-buffered; flush so progress is visible while the
        // rollback runs. Display-only, so a flush failure is ignored.
        let _ = std::io::stdout().flush();

        match rollback_migration(&mut executor, *version, db_type).await {
            Ok(_) => {
                println!("✓");
                success_count += 1;
            }
            Err(e) => {
                println!("❌ 失败: {}", e);
                println!("\n⚠️ 回滚过程中发生错误,停止执行");
                return Err(DbError::Migration(format!(
                    "Migration rollback failed for v{}: {}",
                    version, e
                )));
            }
        }
    }

    println!(
        "\n✅ 成功回滚 {} / {} 个迁移",
        success_count,
        versions_to_rollback.len()
    );
    println!("\n{}", "─".repeat(60));
    Ok(())
}
/// Delete `version` from the `dbnexus_migrations` history table inside a
/// transaction on the executor's connection.
///
/// NOTE(review): only the history row is removed; the migration's DOWN SQL is
/// not executed here, so schema changes made by the migration remain.
///
/// # Errors
/// `DbError::Config` for DuckDB (no SeaORM backend mapping), and
/// `DbError::Connection` if beginning, executing or committing the
/// transaction fails.
async fn rollback_migration(
    executor: &mut MigrationExecutor,
    version: u32,
    db_type: MigrationDatabaseType,
) -> DbResult<()> {
    use sea_orm::{ConnectionTrait, DatabaseTransaction, TransactionTrait};

    // SeaORM passes raw statement SQL to the driver unchanged, so the bind
    // placeholder must match the backend: PostgreSQL uses numbered `$1`,
    // MySQL and SQLite use `?`. A `?` on PostgreSQL is a syntax error.
    let (backend, delete_sql) = match db_type {
        MigrationDatabaseType::Postgres => (
            sea_orm::DbBackend::Postgres,
            "DELETE FROM dbnexus_migrations WHERE version = $1",
        ),
        MigrationDatabaseType::MySql => (
            sea_orm::DbBackend::MySql,
            "DELETE FROM dbnexus_migrations WHERE version = ?",
        ),
        MigrationDatabaseType::Sqlite => (
            sea_orm::DbBackend::Sqlite,
            "DELETE FROM dbnexus_migrations WHERE version = ?",
        ),
        MigrationDatabaseType::DuckDb => {
            return Err(DbError::Config(
                "Migration rollback for DuckDB is not supported via SeaORM backend".to_string(),
            ));
        }
    };

    let delete_stmt = sea_orm::Statement::from_sql_and_values(
        backend,
        delete_sql.to_string(),
        vec![version.into()],
    );

    let conn = &executor.connection;
    let txn: DatabaseTransaction = TransactionTrait::begin(conn).await.map_err(DbError::Connection)?;
    txn.execute_raw(delete_stmt).await.map_err(DbError::Connection)?;
    txn.commit().await.map_err(DbError::Connection)?;
    Ok(())
}
/// Write a migration file to `output`.
///
/// When both `from_schema` and `to_schema` are given, both files are read and
/// their diff (from `generate_schema_diff_sql`) fills the UP/DOWN sections.
/// Otherwise a blank manual template is written. Missing parent directories
/// of `output` are created, and an existing file at `output` is overwritten.
///
/// # Errors
/// `DbError::Config` when the clock is before the Unix epoch, a schema file
/// cannot be read, the output directory cannot be created, or the file cannot
/// be written. Errors from the diff generator are propagated.
async fn generate_migration(
    from_schema: &Option<PathBuf>,
    to_schema: &Option<PathBuf>,
    output: &Path,
    description: &str,
) -> DbResult<()> {
    println!("\n╔══════════════════════════════════════════════════════════════╗");
    println!("║ 生成迁移文件 ║");
    println!("╚══════════════════════════════════════════════════════════════╝");

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|e| DbError::Config(format!("无法解析时间戳: {}", e)))?
        .as_secs();

    // The description is echoed into a `--` comment line; a raw line break would
    // end that comment and turn the rest of the description into live SQL.
    let header_description = description.replace(|c: char| c == '\r' || c == '\n', " ");

    let migration_content = if let (Some(from), Some(to)) = (from_schema, to_schema) {
        println!("\n📄 解析 Schema 文件...");
        let from_content = fs::read_to_string(from)
            .map_err(|e| DbError::Config(format!("无法读取源 schema 文件: {}", e)))?;
        let to_content = fs::read_to_string(to)
            .map_err(|e| DbError::Config(format!("无法读取目标 schema 文件: {}", e)))?;
        let diff_sql = generate_schema_diff_sql(&from_content, &to_content)?;
        let content = format!(
            r#"-- Migration: {description}
-- Version: {timestamp}
-- Created: {created_at}
-- Type: Auto-generated from schema diff
-- UP: Apply migration
{up_sql}
-- DOWN: Rollback migration
{down_sql}
"#,
            description = header_description,
            timestamp = timestamp,
            created_at = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S"),
            up_sql = diff_sql.up,
            down_sql = diff_sql.down
        );
        println!("✓ 已生成 schema 差异 SQL");
        content
    } else {
        let content = format!(
            r#"-- Migration: {description}
-- Version: {timestamp}
-- Created: {created_at}
-- Type: Manual migration
-- UP: Apply migration
-- Your migration SQL goes here
-- DOWN: Rollback migration
-- Reversal of migration SQL goes here
"#,
            description = header_description,
            timestamp = timestamp,
            created_at = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S")
        );
        println!("⚠️ 未提供 schema 文件,已生成空白模板");
        content
    };

    if let Some(parent) = output.parent() {
        if !parent.exists() {
            fs::create_dir_all(parent).map_err(|e| DbError::Config(format!("无法创建输出目录: {}", e)))?;
        }
    }
    fs::write(output, migration_content).map_err(|e| DbError::Config(format!("无法写入迁移文件: {}", e)))?;

    println!("\n✓ 迁移文件已生成: {}", output.display());
    if from_schema.is_some() && to_schema.is_some() {
        println!(" 请检查并编辑生成的迁移文件以确保正确性");
    }
    println!("\n{}", "─".repeat(60));
    Ok(())
}
/// UP/DOWN SQL pair produced by a schema diff.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct DiffSql {
    /// SQL written directly after the `-- UP: Apply migration` line.
    up: String,
    /// SQL written directly after the `-- DOWN: Rollback migration` line.
    down: String,
}
/// Placeholder schema differ.
///
/// Both schema texts are ignored. The result is fixed comment-only UP/DOWN SQL
/// that the user is expected to replace by hand. Always returns `Ok`.
fn generate_schema_diff_sql(_from_content: &str, _to_content: &str) -> Result<DiffSql, DbError> {
    let up = String::from("-- 自动生成的 UP SQL 请手动编辑");
    let down = String::from("-- 自动生成的 DOWN SQL 请手动编辑");
    Ok(DiffSql { up, down })
}
/// Execute a migration's UP section on `session`, then record it through `executor`.
///
/// * `content` – full migration file text. Its description comes from
///   `MigrationFileParser::parse_migration_file`, falling back to `"Migration"`.
/// * `version` – version number given to the recorded `MigrationFile`.
///
/// UP SQL containing `DROP DATABASE` or `TRUNCATE TABLE` (case-insensitive, any
/// whitespace between the two words) is rejected with `DbError::Migration`
/// before anything runs. This is a coarse guard, not a SQL parser; e.g. comments
/// between the keywords are not detected.
///
/// NOTE(review): the UP SQL is executed via `session.execute_raw`, and the
/// `MigrationFile` handed to `apply_migration_file_public` is built with an
/// empty `String` as its last argument — confirm the executor does not expect
/// to run the SQL itself.
// Not called by any subcommand in this file; allowed so the build stays warning-free.
#[allow(dead_code)]
async fn parse_and_apply_migration(
    session: &mut dbnexus::Session,
    executor: &mut MigrationExecutor,
    content: &str,
    version: u32,
) -> DbResult<()> {
    let (description, _full_content) = MigrationFileParser::parse_migration_file(content)
        .unwrap_or(("Migration".to_string(), content.to_string()));

    let up_sql = extract_sql_section(content, "UP")?;
    if !up_sql.trim().is_empty() {
        // Collapse whitespace runs to a single space before matching, so
        // "DROP\n  DATABASE" cannot slip past a plain substring check.
        let normalized = up_sql
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ")
            .to_uppercase();
        const FORBIDDEN_PATTERNS: [&str; 2] = ["DROP DATABASE", "TRUNCATE TABLE"];
        if let Some(pattern) = FORBIDDEN_PATTERNS
            .iter()
            .find(|p| normalized.contains(**p))
        {
            return Err(DbError::Migration(format!(
                "Forbidden pattern in migration SQL: {} ({})",
                pattern, pattern
            )));
        }
        session.execute_raw(&up_sql).await?;
    }

    let file_path = format!("migration_v{}.sql", version);
    let migration_file = MigrationFile::new(
        version,
        description,
        std::path::PathBuf::from(&file_path),
        String::new(),
    );
    executor.apply_migration_file_public(&migration_file).await?;
    Ok(())
}
#[allow(dead_code)]
fn extract_sql_section(content: &str, section: &str) -> Result<String, DbError> {
let section_start_pattern = format!("-- {}:", section);
let section_end_pattern = format!("-- {}", if section == "UP" { "DOWN" } else { "UP" });
let start_match = content.find(§ion_start_pattern);
let end_match = content.find(§ion_end_pattern);
if let Some(start_idx) = start_match {
let line_end = content[start_idx..]
.find('\n')
.map(|offset| start_idx + offset + 1) .unwrap_or(start_idx + section_start_pattern.len());
if let Some(end_idx) = end_match {
if end_idx > start_idx {
Ok(content[line_end..end_idx].trim().to_string())
} else {
Ok(content[line_end..].trim().to_string())
}
} else {
Ok(content[line_end..].trim().to_string())
}
} else {
Ok(String::new())
}
}
/// Print every migration file found in `migrations_dir`, numbered from 1.
///
/// A database connection and session are still opened, because scanning goes
/// through the migration executor created from the session.
async fn list_migrations(database_url: &str, migrations_dir: &Path) -> DbResult<()> {
    println!("\n╔══════════════════════════════════════════════════════════════╗");
    println!("║ 迁移文件列表 ║");
    println!("╚══════════════════════════════════════════════════════════════╝");

    let pool = DbPool::new(database_url).await?;
    let backend_kind = detect_database_type(database_url)?;
    let session = pool.get_session("admin").await?;
    let scanner = session.create_migration_executor(backend_kind)?;
    let files = scanner.scan_migrations(migrations_dir)?;

    if files.is_empty() {
        println!("\n⚠️ 迁移目录中没有找到迁移文件");
        println!(" 目录: {}", migrations_dir.display());
        return Ok(());
    }

    println!("\n📁 迁移目录: {}", migrations_dir.display());
    println!("📦 共 {} 个迁移文件\n", files.len());
    files.iter().enumerate().for_each(|(position, file)| {
        println!(" [{:2}] v{:6} - {}", position + 1, file.version(), file.description());
    });

    println!("\n{}", "─".repeat(60));
    Ok(())
}
/// Map the scheme of `database_url` to the migration backend type.
///
/// `postgres`/`postgresql` → Postgres, `mysql` → MySql, and
/// `sqlite`/`sqlite3`/`file` → Sqlite. NOTE(review): treating `file:` URLs as
/// SQLite is a heuristic — confirm it is intended.
///
/// # Errors
/// `DbError::Config` when the URL does not parse, or when the scheme is Oracle,
/// SQL Server or any other unsupported protocol.
fn detect_database_type(database_url: &str) -> Result<MigrationDatabaseType, DbError> {
    let parsed = match url::Url::parse(database_url) {
        Ok(parsed) => parsed,
        Err(e) => return Err(DbError::Config(format!("Invalid database URL format: {}", e))),
    };
    let scheme = parsed.scheme().to_lowercase();

    let detected = match scheme.as_str() {
        "postgres" | "postgresql" => MigrationDatabaseType::Postgres,
        "mysql" => MigrationDatabaseType::MySql,
        "sqlite" | "sqlite3" | "file" => MigrationDatabaseType::Sqlite,
        "oci" | "oracle" => {
            return Err(DbError::Config("Oracle database is not supported".to_string()));
        }
        "mssql" | "sqlserver" => {
            return Err(DbError::Config("SQL Server database is not supported".to_string()));
        }
        other => {
            return Err(DbError::Config(format!(
                "Unsupported database protocol: '{}'. Supported protocols: sqlite, postgres, mysql",
                other
            )));
        }
    };
    Ok(detected)
}
/// Return `url` with any password replaced by a fixed mask, for display.
///
/// The mask has a constant width, so the printed URL reveals nothing about
/// the password, not even its length (the old per-character mask did).
/// Strings that do not parse as a URL are returned unchanged.
fn mask_database_url(url: &str) -> String {
    const PASSWORD_MASK: &str = "****";
    match url::Url::parse(url) {
        Ok(mut parsed) => {
            if parsed.password().is_some() {
                // `set_password` only fails for URLs that cannot carry
                // credentials; in that case there is nothing to mask.
                let _ = parsed.set_password(Some(PASSWORD_MASK));
            }
            parsed.to_string()
        }
        Err(_) => url.to_string(),
    }
}