use clap::Parser;
use anyhow::Result;
mod server;
pub use server::start_http_server;
// Top-level command-line arguments for the `sqltool` binary.
//
// Plain `//` comments are used on purpose: clap turns `///` doc comments into
// `--help` text, which would change the program's visible output.
#[derive(Parser, Debug)]
#[command(
    author = "SQLTool Team",
    version,
    about = "SQLTool - 智能数据库迁移与运维工具",
    long_about = "功能强大的数据库迁移、同步、运维工具,支持:\n\
                  - 数据库迁移与同步\n\
                  - 自动分库分表\n\
                  - 慢查询检测\n\
                  - 数据对比与备份\n\
                  - HTTP API 服务模式"
)]
pub struct Args {
    // `-v` / `--verbose`: when set, `execute` prints a version banner before
    // running the subcommand.
    #[arg(short = 'v', long = "verbose", default_value_t = false)]
    pub verbose: bool,

    // The subcommand to run.
    #[command(subcommand)]
    pub command: Command,
}
#[derive(Parser, Debug)]
pub enum Command {
Transfer {
#[arg(short = 's', long)]
source: String,
#[arg(short = 't', long)]
target: String,
#[arg(short = 'S', long, default_value = "mysql")]
source_type: String,
#[arg(short = 'T', long, default_value = "postgresql")]
target_type: String,
#[arg(short, long)]
tables: Option<String>,
#[arg(short = 'B', long, default_value_t = 1000)]
batch_size: usize,
#[arg(short = 'v', long, default_value_t = true)]
verify: bool,
},
MigrateSchema {
#[arg(short = 's', long)]
source: String,
#[arg(short = 't', long)]
target: String,
#[arg(short = 'S', long, default_value = "mysql")]
source_type: String,
#[arg(short = 'T', long, default_value = "postgresql")]
target_type: String,
#[arg(short, long)]
table: String,
},
CompareData {
#[arg(short = 's', long)]
source: String,
#[arg(short = 't', long)]
target: String,
#[arg(short = 'S', long, default_value = "mysql")]
source_type: String,
#[arg(short = 'T', long, default_value = "postgresql")]
target_type: String,
#[arg(short, long)]
table: String,
#[arg(short, long, default_value = "id")]
primary_key: String,
#[arg(short, long)]
ignore_fields: Option<String>,
#[arg(short, long, default_value = "json")]
output: String,
},
Backup {
#[arg(short = 's', long)]
source: String,
#[arg(short = 'T', long, default_value = "mysql")]
db_type: String,
#[arg(short, long)]
output: String,
#[arg(short, long, default_value = "full")]
backup_type: String,
#[arg(short = 'c', long, default_value_t = true)]
compress: bool,
#[arg(long, default_value_t = true)]
include_procedures: bool,
#[arg(long, default_value_t = true)]
include_functions: bool,
#[arg(long, default_value_t = true)]
include_triggers: bool,
},
Restore {
#[arg(short, long)]
backup: String,
#[arg(short = 't', long)]
target: String,
#[arg(short = 'T', long, default_value = "mysql")]
db_type: String,
},
CreateShard {
#[arg(short = 's', long)]
source: String,
#[arg(short, long)]
table: String,
#[arg(short, long, default_value = "row_count")]
strategy: String,
#[arg(short, long)]
threshold: Option<String>,
#[arg(short, long, default_value = "shard")]
prefix: String,
},
SpanningQuery {
#[arg(short = 's', long)]
source: String,
#[arg(short, long)]
table: String,
#[arg(short, long, default_value = "1=1")]
condition: String,
#[arg(short, long)]
order_by: Option<String>,
#[arg(short, long, default_value = "ASC")]
order_dir: String,
#[arg(short = 'L', long, default_value_t = 100)]
limit: u64,
#[arg(short, long, default_value_t = 0)]
offset: u64,
#[arg(short, long, default_value = "json")]
output: String,
},
DetectSlowQuery {
#[arg(short = 's', long)]
source: String,
#[arg(short = 'T', long, default_value = "mysql")]
db_type: String,
#[arg(short, long, default_value_t = 1000)]
threshold_ms: u64,
#[arg(short, long)]
query_file: Option<String>,
#[arg(short, long)]
sql: Option<String>,
#[arg(short, long, default_value = "json")]
output: String,
},
InsertLog {
#[arg(short = 's', long)]
source: String,
#[arg(short, long, default_value = "app_logs")]
table: String,
#[arg(short, long, default_value = "INFO")]
level: String,
#[arg(short, long)]
message: String,
#[arg(short, long)]
source_name: Option<String>,
},
QueryLogs {
#[arg(short = 's', long)]
source: String,
#[arg(short, long, default_value = "app_logs")]
table: String,
#[arg(short, long)]
levels: Option<String>,
#[arg(short, long)]
keyword: Option<String>,
#[arg(short, long)]
start_time: Option<i64>,
#[arg(long)]
end_time: Option<i64>,
#[arg(short = 'L', long, default_value_t = 100)]
limit: u64,
#[arg(short, long, default_value = "json")]
output: String,
},
Server {
#[arg(short, long, default_value = "127.0.0.1")]
host: String,
#[arg(short = 'p', long, default_value_t = 8080)]
port: u16,
#[arg(short = 's', long)]
source: Option<String>,
#[arg(short = 'T', long, default_value = "mysql")]
db_type: String,
#[arg(long, default_value_t = false)]
cors: bool,
#[arg(long)]
api_key: Option<String>,
},
DetectSqlInjection {
#[arg(short, long)]
input: String,
#[arg(short, long, default_value_t = false)]
strict: bool,
},
BuildSafeSql {
#[arg(short, long)]
table: String,
#[arg(short, long)]
field: String,
#[arg(short, long, default_value = "=")]
operator: String,
#[arg(short, long)]
value: String,
},
}
/// Validates the selected subcommand's arguments and runs it.
///
/// For most subcommands this currently means validating the inputs and
/// printing a status line; `server`, `detect-sql-injection` and
/// `build-safe-sql` do their work here. When `args.verbose` is set, a version
/// banner is printed first.
///
/// # Errors
/// Returns an error when argument validation fails, when
/// `detect-sql-injection --strict` reports a high or critical risk, or when the
/// HTTP server fails to start.
pub async fn execute(args: Args) -> Result<()> {
    let verbose = args.verbose;
    let command = args.command;
    if verbose {
        println!("SQLTool v{} - 智能数据库迁移与运维工具", env!("CARGO_PKG_VERSION"));
        println!("=======================================\n");
    }
    match command {
        Command::Transfer { source, target, source_type, target_type, .. } => {
            validate_connection_string(&source, &source_type)?;
            validate_connection_string(&target, &target_type)?;
            println!("执行数据迁移命令...");
        }
        Command::MigrateSchema { source, target, source_type, target_type, .. } => {
            validate_connection_string(&source, &source_type)?;
            validate_connection_string(&target, &target_type)?;
            println!("执行结构迁移命令...");
        }
        Command::CompareData { source, target, source_type, target_type, .. } => {
            validate_connection_string(&source, &source_type)?;
            validate_connection_string(&target, &target_type)?;
            println!("执行数据对比命令...");
        }
        Command::Backup { source, db_type, .. } => {
            validate_connection_string(&source, &db_type)?;
            println!("执行数据库备份命令...");
        }
        Command::Restore { target, db_type, .. } => {
            validate_connection_string(&target, &db_type)?;
            println!("执行数据库恢复命令...");
        }
        // The next four subcommands have no `--db-type` option, so MySQL is assumed.
        Command::CreateShard { source, table, .. } => {
            validate_connection_string(&source, "mysql")?;
            validate_table_name(&table)?;
            println!("执行创建分片命令...");
        }
        Command::SpanningQuery { source, table, condition, .. } => {
            validate_connection_string(&source, "mysql")?;
            validate_table_name(&table)?;
            validate_condition(&condition)?;
            println!("执行跨分片查询命令...");
        }
        Command::DetectSlowQuery { source, db_type, sql, query_file, .. } => {
            validate_connection_string(&source, &db_type)?;
            if sql.is_none() && query_file.is_none() {
                anyhow::bail!("必须指定 --sql 或 --query-file");
            }
            println!("执行慢查询检测命令...");
        }
        Command::InsertLog { source, table, level, message, .. } => {
            validate_connection_string(&source, "mysql")?;
            validate_table_name(&table)?;
            validate_log_level(&level)?;
            if message.is_empty() {
                anyhow::bail!("日志消息不能为空");
            }
            println!("执行插入日志命令...");
        }
        Command::QueryLogs { source, levels, .. } => {
            validate_connection_string(&source, "mysql")?;
            // `--levels` may carry a comma-separated list such as "INFO,ERROR";
            // validate each entry. A single level is validated exactly as before.
            if let Some(levels) = levels.as_deref() {
                for level in levels.split(',') {
                    validate_log_level(level.trim())?;
                }
            }
            println!("执行查询日志命令...");
        }
        Command::Server { host, port, source, db_type, cors, api_key } => {
            // Check the connection string against the declared --db-type
            // instead of assuming MySQL.
            if let Some(ref s) = source {
                validate_connection_string(s, &db_type)?;
            }
            if port < 1024 {
                println!("警告: 端口 {} 小于1024,可能需要root权限", port);
            }
            println!("启动 HTTP API 服务器...");
            // Hand every parsed option to the server so --host, --db-type,
            // --cors and --api-key actually take effect.
            println!("HTTP API 服务器配置:");
            println!(" 监听地址: {}:{}", host, port);
            println!(" 数据库类型: {}", db_type);
            println!(" CORS 启用: {}", cors);
            println!(" API 密钥: {}", if api_key.is_some() { "已设置" } else { "未设置" });
            start_http_server(host, port, source, db_type, cors, api_key).await?;
        }
        Command::DetectSqlInjection { input, strict } => {
            // The limit is stated in characters, so count chars, not UTF-8 bytes.
            if input.chars().count() > 10000 {
                anyhow::bail!("输入过长,最大10000字符");
            }
            use crate::utils::SqlInjectionDetector;
            let detector = SqlInjectionDetector::new();
            let report = detector.detect(&input);
            println!("\n========== SQL 注入检测结果 ==========");
            println!("输入: {}", input);
            if let Some(r) = report {
                println!("风险等级: {:?}", r.risk_level);
                println!("发现 {} 个问题:", r.findings.len());
                for finding in &r.findings {
                    println!(" - [{:?}] {}", finding.category, finding.description);
                }
                // In strict mode a high/critical finding turns into a failing exit.
                if strict
                    && matches!(
                        r.risk_level,
                        crate::utils::RiskLevel::High | crate::utils::RiskLevel::Critical
                    )
                {
                    anyhow::bail!("检测到高风险 SQL 注入攻击!");
                }
            } else {
                println!("风险等级: None");
                println!("未发现 SQL 注入风险");
            }
            println!("======================================\n");
        }
        Command::BuildSafeSql { table, field, operator, value } => {
            use crate::utils::SafeSqlBuilder;
            validate_table_name(&table)?;
            validate_field_name(&field)?;
            validate_operator(&operator)?;
            // The limit is stated in characters, so count chars, not UTF-8 bytes.
            if value.chars().count() > 10000 {
                anyhow::bail!("值过长,最大10000字符");
            }
            println!("\n========== 安全 SQL 构建结果 ==========");
            println!("表: {}", table);
            println!("字段: {}", field);
            println!("操作符: {}", operator);
            println!("原始值: {}", value);
            // Builder failures are reported on stdout but do not fail the command.
            match SafeSqlBuilder::new(&table) {
                Ok(builder) => {
                    match builder.safe_where(&field, &operator, &serde_json::json!(&value)) {
                        Ok(sql) => {
                            println!("安全 SQL: {}", sql);
                        }
                        Err(e) => {
                            println!("构建失败: {}", e);
                        }
                    }
                }
                Err(e) => {
                    println!("构建器创建失败: {}", e);
                }
            }
            println!("======================================\n");
        }
    }
    Ok(())
}
/// Parses `conn` with the shared `crate::utils` validator and cross-checks its type.
///
/// A mismatch between the type detected in `conn` and `db_type` is only logged
/// as a warning; it does not fail validation.
///
/// # Errors
/// Returns an error prefixed with `连接字符串验证失败` when the underlying
/// validator rejects `conn`.
fn validate_connection_string(conn: &str, db_type: &str) -> Result<()> {
    // Aliased so the helper does not shadow this function's own name.
    use crate::utils::validate_connection_string as parse_connection;

    let parsed = match parse_connection(conn) {
        Ok(parsed) => parsed,
        Err(err) => return Err(anyhow::anyhow!("连接字符串验证失败: {}", err)),
    };

    if parsed.db_type != db_type {
        log::warn!("连接字符串类型 {} 与指定类型 {} 不匹配", parsed.db_type, db_type);
    }
    Ok(())
}
/// Validates a table name supplied on the command line, optionally
/// dot-qualified (e.g. `schema.table`).
///
/// Accepted characters are those for which `char::is_alphanumeric` holds
/// (including non-ASCII letters/digits), `_`, and `.` as a separator between
/// non-empty segments.
///
/// # Errors
/// Returns an error when the name is empty, longer than 64 characters,
/// contains any other character, or has an empty dot-separated segment.
fn validate_table_name(table: &str) -> Result<()> {
    if table.is_empty() {
        anyhow::bail!("表名不能为空");
    }
    // The limit (and its message) is in characters; `len()` would count UTF-8
    // bytes and reject short names written in multi-byte scripts.
    if table.chars().count() > 64 {
        anyhow::bail!("表名过长,最大64字符");
    }
    if !table.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '.') {
        anyhow::bail!("表名只能包含字母、数字、下划线和点");
    }
    // `.` only separates identifiers: reject ".", "a.", ".b" and "a..b".
    if table.split('.').any(str::is_empty) {
        anyhow::bail!("表名格式无效,点号两侧必须是非空标识符");
    }
    Ok(())
}
/// Validates a column name supplied on the command line.
///
/// Only characters for which `char::is_alphanumeric` holds (including
/// non-ASCII letters/digits) and `_` are accepted; unlike table names, dots
/// are not allowed.
///
/// # Errors
/// Returns an error when the name is empty, longer than 64 characters, or
/// contains any other character.
fn validate_field_name(field: &str) -> Result<()> {
    if field.is_empty() {
        anyhow::bail!("字段名不能为空");
    }
    // The limit (and its message) is in characters; `len()` would count UTF-8
    // bytes and reject short names written in multi-byte scripts.
    if field.chars().count() > 64 {
        anyhow::bail!("字段名过长,最大64字符");
    }
    if !field.chars().all(|c| c.is_alphanumeric() || c == '_') {
        anyhow::bail!("字段名只能包含字母、数字和下划线");
    }
    Ok(())
}
/// Sanity-checks a free-form WHERE condition.
///
/// The condition is rejected only when it is longer than 2000 characters.
/// DDL/DML keywords (`DROP`, `DELETE`, ...) appearing as whole words are
/// logged as warnings, once per keyword, but are NOT rejected.
///
/// # Errors
/// Returns an error when the condition exceeds 2000 characters.
fn validate_condition(condition: &str) -> Result<()> {
    const DANGEROUS: [&str; 7] = [
        "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE",
    ];

    // The limit (and its message) is in characters, not UTF-8 bytes.
    if condition.chars().count() > 2000 {
        anyhow::bail!("条件过长,最大2000字符");
    }

    // Uppercase once, then compare whole words. Substring matching on "KW "
    // both over-matched (`last_update = 1` hit UPDATE) and under-matched
    // (`DROP\tTABLE`, `DROP/**/TABLE`, or a keyword at the very end).
    let upper = condition.to_uppercase();
    let words: Vec<&str> = upper
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|w| !w.is_empty())
        .collect();

    for kw in DANGEROUS {
        if words.contains(&kw) {
            log::warn!("条件包含危险关键词: {}", kw);
        }
    }
    Ok(())
}
/// Checks `op` against the fixed list of comparison operators this CLI accepts.
///
/// Matching is ASCII case-insensitive, so `like` and `is null` are accepted;
/// multi-word operators must use single spaces (`IS NOT NULL`).
///
/// # Errors
/// Returns an error listing every accepted operator when `op` is not one of them.
fn validate_operator(op: &str) -> Result<()> {
    const ALLOWED: [&str; 12] = [
        "=", "!=", "<>", "<", ">", "<=", ">=", "LIKE", "IN", "BETWEEN", "IS NULL", "IS NOT NULL",
    ];

    let recognized = ALLOWED
        .iter()
        .any(|candidate| op.eq_ignore_ascii_case(candidate));
    if recognized {
        Ok(())
    } else {
        Err(anyhow::anyhow!("无效的操作符: {},有效值: {:?}", op, ALLOWED))
    }
}
/// Checks that `level` names one of the supported log levels.
///
/// Matching is ASCII case-insensitive (`info` and `INFO` are both accepted).
///
/// # Errors
/// Returns an error listing the supported levels when `level` is unknown.
fn validate_log_level(level: &str) -> Result<()> {
    const LEVELS: [&str; 6] = ["DEBUG", "INFO", "WARN", "ERROR", "TRACE", "FATAL"];

    if LEVELS.iter().any(|name| level.eq_ignore_ascii_case(name)) {
        return Ok(());
    }
    Err(anyhow::anyhow!(
        "无效的日志级别: {},有效值: {}",
        level,
        LEVELS.join(", ")
    ))
}
/// Starts the HTTP API server on `port`, using fixed defaults for everything else.
///
/// The listen host is always `127.0.0.1`, the database type `mysql`, CORS is
/// disabled and no API key is configured; only `port` and `source` come from
/// the caller. The effective configuration is printed before the server starts.
///
/// NOTE(review): callers that need user-supplied host / db type / CORS / API
/// key values must call `start_http_server` directly.
///
/// # Errors
/// Propagates any error returned by `start_http_server`.
async fn start_server(port: u16, source: Option<String>) -> Result<()> {
    let (host, db_type, cors) = (String::from("127.0.0.1"), String::from("mysql"), false);
    // The key's concrete `Option<_>` type is inferred from `start_http_server`.
    let api_key = None;
    let api_key_state = if api_key.is_some() { "已设置" } else { "未设置" };

    println!("HTTP API 服务器配置:");
    println!(" 监听地址: {host}:{port}");
    println!(" 数据库类型: {db_type}");
    println!(" CORS 启用: {cors}");
    println!(" API 密钥: {api_key_state}");

    start_http_server(host, port, source, db_type, cors, api_key).await?;
    Ok(())
}
/// Prints a hand-written usage summary (separate from clap's generated `--help`).
///
/// The subcommand names listed here must match the kebab-case names clap
/// derives from the `Command` variants (e.g. `DetectSlowQuery` ->
/// `detect-slow-query`); otherwise users copying them get "unrecognized
/// subcommand".
pub fn print_help() {
    // This is a raw string, so backslashes are printed literally: a single `\`
    // is the shell line-continuation shown in the examples (`\\` would print
    // two backslashes and break copy-paste).
    println!(r#"
SQLTool - 智能数据库迁移与运维工具 v{}
用法:
sqltool [选项] <子命令>
子命令:
transfer 数据迁移 - 在两个数据库之间迁移数据
migrate-schema 结构迁移 - 迁移表结构(索引、约束等)
compare-data 数据对比 - 对比两个数据库的数据
backup 数据库备份 - 备份整个数据库
restore 数据库恢复 - 从备份恢复数据库
create-shard 创建分片 - 为大表创建分片
spanning-query 跨分片查询 - 查询多个分片的数据
detect-slow-query 慢查询检测 - 检测和分析慢查询
insert-log 插入日志 - 向日志表插入日志
query-logs 查询日志 - 查询日志表数据
server HTTP API 服务器 - 启动 REST API 服务
detect-sql-injection SQL注入检测 - 检测 SQL 注入风险
build-safe-sql 安全SQL构建 - 构建安全的 SQL 语句
选项:
-v, --verbose 启用详细输出
-h, --help 显示帮助信息
-V, --version 显示版本信息
示例:
# 数据迁移
sqltool transfer -s mysql://root:pass@localhost:3306/source_db \
-t postgresql://postgres:pass@localhost:5432/target_db
# 数据库备份
sqltool backup -s mysql://root:pass@localhost:3306/mydb \
--output ./backup.sql
# 启动 API 服务器
sqltool server -p 8080 -s mysql://root:pass@localhost:3306/mydb
更多信息: https://github.com/yourusername/sqltool
"#, env!("CARGO_PKG_VERSION"));
}