sqltool 0.3.0

功能强大的数据库迁移、同步、运维工具，支持自动分库分表、慢查询检测、数据对比、备份恢复，提供多语言SDK。
Documentation
use axum::{
    Router,
    routing::{get, post},
    Json, extract::State,
    http::StatusCode,
    response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use tower_http::cors::{CorsLayer, Any};

/// JSON body accepted by `POST /api/transfer` (see `transfer_handler`): a request to migrate
/// data from `source` to `target`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TransferRequest {
    /// Source database; presumably a connection string — TODO confirm the expected format.
    pub source: String,
    /// Target database; presumably a connection string — TODO confirm the expected format.
    pub target: String,
    /// Database type of `source` (free-form; the accepted values are not defined in this module).
    pub source_type: String,
    /// Database type of `target` (free-form; the accepted values are not defined in this module).
    pub target_type: String,
    /// Optional list of table names to transfer; presumably `None` means "all tables" — TODO confirm.
    pub tables: Option<Vec<String>>,
    /// Optional batch size; the default used when `None` is not defined in this module.
    pub batch_size: Option<usize>,
}

/// Payload placed in `ApiResponse::data` by `POST /api/transfer`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TransferResponse {
    /// Whether the transfer request was accepted.
    pub success: bool,
    /// Human-readable status message.
    pub message: String,
    /// Number of records moved, when known (the visible handler always reports `Some(0)`).
    pub records_transferred: Option<u64>,
}

/// JSON body accepted by `POST /api/backup` (see `backup_handler`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BackupRequest {
    /// Database to back up; presumably a connection string — TODO confirm the expected format.
    pub source: String,
    /// Database type of `source` (free-form; accepted values are not defined in this module).
    pub db_type: String,
    /// Optional backup kind; the accepted values are not defined in this module.
    pub backup_type: Option<String>,
    /// Optional flag requesting a compressed backup.
    pub compress: Option<bool>,
}

/// Payload placed in `ApiResponse::data` by `POST /api/backup`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BackupResponse {
    /// Whether the backup request was accepted.
    pub success: bool,
    /// Human-readable status message.
    pub message: String,
    /// Path reported for the backup file, if any.
    pub backup_path: Option<String>,
}

/// JSON body accepted by `POST /api/query` (see `query_handler`).
#[derive(Debug, Serialize, Deserialize)]
pub struct QueryRequest {
    /// SQL text to execute.
    pub sql: String,
    /// Database type the SQL targets (free-form; accepted values are not defined in this module).
    pub db_type: String,
}

/// Payload placed in `ApiResponse::data` by `POST /api/query`.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueryResponse {
    /// Whether the query succeeded.
    pub success: bool,
    /// Result rows, each as an arbitrary JSON value (row shape is not fixed by this type).
    pub data: Option<Vec<serde_json::Value>>,
    /// Column names of the result set.
    pub columns: Option<Vec<String>>,
    /// Number of rows returned.
    pub row_count: Option<usize>,
    /// Optional status or error message.
    pub message: Option<String>,
}

/// Payload returned by `GET /api/health`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Service status string (`health_handler` always sends `"healthy"`).
    pub status: String,
    /// Crate version, filled from `CARGO_PKG_VERSION` by `health_handler`.
    pub version: String,
    /// Time of the health check; `health_handler` sends the current UTC time in RFC 3339 format.
    pub timestamp: String,
}

/// Common JSON envelope wrapping every API reply.
///
/// `None` fields are serialized as `null` (there is no `skip_serializing_if`).
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiResponse<T> {
    /// Whether the request succeeded.
    pub success: bool,
    /// Endpoint-specific payload on success.
    pub data: Option<T>,
    /// Optional human-readable message, e.g. the reason for a failure.
    pub message: Option<String>,
}

impl<T: Serialize> IntoResponse for ApiResponse<T> {
    /// Renders the envelope as a JSON body whose HTTP status mirrors `success`:
    /// `400 Bad Request` when it is `false`, `200 OK` otherwise.
    ///
    /// NOTE(review): the handlers in this module return `Json<ApiResponse<_>>`, which serializes
    /// through `Json` and bypasses this impl, so their `success: false` replies go out as `200 OK`.
    fn into_response(self) -> axum::response::Response {
        if !self.success {
            return (StatusCode::BAD_REQUEST, Json(self)).into_response();
        }
        (StatusCode::OK, Json(self)).into_response()
    }
}

/// Configuration shared with every route handler through `State<Arc<AppState>>`.
#[derive(Clone)]
struct AppState {
    /// Database connection string given to `start_http_server`; no handler in this module reads it.
    db_connection: Option<String>,
    /// Database type given to `start_http_server`; no handler in this module reads it.
    db_type: String,
    /// Optional API key given to `start_http_server`.
    /// NOTE(review): only `transfer_handler` reads it, and only to reject an empty key; it is
    /// never compared against anything the client sends. Confirm the intended auth scheme.
    api_key: Option<String>,
}

pub async fn start_http_server(
    host: String,
    port: u16,
    db_connection: Option<String>,
    db_type: String,
    cors_enabled: bool,
    api_key: Option<String>,
) -> Result<(), anyhow::Error> {
    let app_state = Arc::new(AppState {
        db_connection,
        db_type,
        api_key,
    });

    let mut app = Router::new()
        .route("/api/health", get(health_handler))
        .route("/api/transfer", post(transfer_handler))
        .route("/api/backup", post(backup_handler))
        .route("/api/query", post(query_handler))
        .route("/api/sharding/create", post(create_shard_handler))
        .route("/api/sharding/query", post(spanning_query_handler))
        .route("/api/logs/insert", post(insert_log_handler))
        .route("/api/logs/query", post(query_logs_handler))
        .route("/api/security/detect-injection", post(detect_injection_handler))
        .route("/api/security/build-safe-sql", post(build_safe_sql_handler))
        .with_state(app_state);

    if cors_enabled {
        let cors = CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any);
        app = app.layer(cors);
    }

    let addr: SocketAddr = format!("{}:{}", host, port).parse().map_err(anyhow::Error::msg)?;
    println!("HTTP API 服务器启动成功!");
    println!("监听地址: {}", addr);
    println!("API 文档: http://{}/api/docs", addr);

    let listener = tokio::net::TcpListener::bind(addr).await.map_err(anyhow::Error::msg)?;
    axum::serve(listener, app).await?;

    Ok(())
}

/// `GET /api/health`: a liveness probe that reports the crate version and the current UTC time.
async fn health_handler() -> Json<ApiResponse<HealthResponse>> {
    let report = HealthResponse {
        status: String::from("healthy"),
        version: String::from(env!("CARGO_PKG_VERSION")),
        timestamp: chrono::Utc::now().to_rfc3339(),
    };

    Json(ApiResponse {
        success: true,
        data: Some(report),
        message: None,
    })
}

async fn transfer_handler(
    State(state): State<Arc<AppState>>,
    Json(req): Json<TransferRequest>,
) -> Json<ApiResponse<TransferResponse>> {
    if let Some(ref key) = state.api_key {
        if key.is_empty() {
            return Json(ApiResponse {
                success: false,
                data: None,
                message: Some("API密钥未设置".to_string()),
            });
        }
    }

    println!("执行数据迁移: {} -> {}", req.source, req.target);

    Json(ApiResponse {
        success: true,
        data: Some(TransferResponse {
            success: true,
            message: "数据迁移任务已提交".to_string(),
            records_transferred: Some(0),
        }),
        message: None,
    })
}

/// `POST /api/backup`: acknowledges a backup request.
///
/// Logs the source, with inline URL credentials masked, and replies with a path of the form
/// `./backup_<unix-seconds>.sql`.
///
/// NOTE(review): no backup is performed here and the returned path is not created. The
/// request's `db_type`, `backup_type` and `compress` fields are ignored. Confirm this is a stub.
async fn backup_handler(
    State(_state): State<Arc<AppState>>,
    Json(req): Json<BackupRequest>,
) -> Json<ApiResponse<BackupResponse>> {
    // Masks the `user:password@` part of URL-style connection strings
    // (`scheme://user:pass@host/db`) so credentials are not written to stdout.
    // Strings without `://`, or without `@` in the authority, are returned unchanged.
    // Key=value DSNs (e.g. `password=...`) are NOT masked.
    fn redact_credentials(conn: &str) -> String {
        let Some(scheme_end) = conn.find("://") else {
            return conn.to_string();
        };
        let (scheme, rest) = conn.split_at(scheme_end + 3);
        // The authority ends at the first path, query or fragment delimiter.
        let authority_len = rest
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .unwrap_or(rest.len());
        // Use the last '@' so that an '@' inside the password is masked as well.
        match rest[..authority_len].rfind('@') {
            Some(at) => format!("{}***@{}", scheme, &rest[at + 1..]),
            None => conn.to_string(),
        }
    }

    println!("执行数据库备份: {}", redact_credentials(&req.source));

    Json(ApiResponse {
        success: true,
        data: Some(BackupResponse {
            success: true,
            message: "备份任务已提交".to_string(),
            backup_path: Some(format!("./backup_{}.sql", chrono::Utc::now().timestamp())),
        }),
        message: None,
    })
}

async fn query_handler(
    State(_state): State<Arc<AppState>>,
    Json(req): Json<QueryRequest>,
) -> Json<ApiResponse<QueryResponse>> {
    println!("执行查询: {}", req.sql);

    Json(ApiResponse {
        success: true,
        data: Some(QueryResponse {
            success: true,
            data: Some(vec![]),
            columns: Some(vec!["id".to_string(), "name".to_string()]),
            row_count: Some(0),
            message: None,
        }),
        message: None,
    })
}

/// `POST /api/sharding/create`: acknowledges a shard-creation request.
///
/// The JSON body is parsed but not inspected. The reply always reports shard `"shard_001"`
/// with status `"created"`.
async fn create_shard_handler(
    State(_state): State<Arc<AppState>>,
    Json(_req): Json<serde_json::Value>,
) -> Json<ApiResponse<serde_json::Value>> {
    println!("创建分片");

    let shard = serde_json::json!({
        "shard_id": "shard_001",
        "status": "created"
    });

    Json(ApiResponse {
        success: true,
        data: Some(shard),
        message: None,
    })
}

/// `POST /api/sharding/query`: acknowledges a cross-shard query.
///
/// The JSON body is parsed but not inspected. The reply is a fixed summary
/// (`shards_queried: 3`, `total_rows: 0`, empty `data`).
async fn spanning_query_handler(
    State(_state): State<Arc<AppState>>,
    Json(_req): Json<serde_json::Value>,
) -> Json<ApiResponse<serde_json::Value>> {
    println!("执行跨分片查询");

    let summary = serde_json::json!({
        "shards_queried": 3,
        "total_rows": 0,
        "data": []
    });

    Json(ApiResponse {
        success: true,
        data: Some(summary),
        message: None,
    })
}

async fn insert_log_handler(
    State(_state): State<Arc<AppState>>,
    Json(_req): Json<serde_json::Value>,
) -> Json<ApiResponse<serde_json::Value>> {
    println!("插入日志");

    Json(ApiResponse {
        success: true,
        data: Some(serde_json::json!({
            "log_id": chrono::Utc::now().timestamp_millis(),
            "status": "inserted"
        })),
        message: None,
    })
}

/// `POST /api/logs/query`: answers a log query with an empty result.
///
/// The JSON body is parsed but not inspected. The reply is always `total: 0` with no `logs`.
async fn query_logs_handler(
    State(_state): State<Arc<AppState>>,
    Json(_req): Json<serde_json::Value>,
) -> Json<ApiResponse<serde_json::Value>> {
    println!("查询日志");

    let page = serde_json::json!({
        "total": 0,
        "logs": []
    });

    Json(ApiResponse {
        success: true,
        data: Some(page),
        message: None,
    })
}

/// `POST /api/security/detect-injection`: heuristic SQL-injection risk rating.
///
/// The request body's `"input"` string is examined. A missing or non-string value is treated
/// as `""`.
///
/// Rating:
/// - `HIGH`: a single quote, a SQL comment opener (`--` or `/*`), or the keyword `UNION`.
/// - `MEDIUM`: none of the above, but the keyword `OR` or `AND` is present.
/// - `LOW`: none of the above.
///
/// Keywords are matched case-insensitively and only as whole words. Ordinary words such as
/// `ORDER`, `BRAND` or `REUNION` therefore do not trigger them, while `union select` does.
/// `findings` lists each matched indicator once, e.g. `["'", "OR"]`.
///
/// This is a coarse heuristic; it does not replace parameterized queries.
async fn detect_injection_handler(
    State(_state): State<Arc<AppState>>,
    Json(req): Json<serde_json::Value>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Raw sequences that indicate string-literal breakout or comment injection.
    // "--" also covers the classic ";--".
    const HIGH_RISK_SEQUENCES: [&str; 3] = ["'", "--", "/*"];
    // Keywords matched as whole, case-insensitive words.
    const HIGH_RISK_KEYWORDS: [&str; 1] = ["UNION"];
    const MEDIUM_RISK_KEYWORDS: [&str; 2] = ["OR", "AND"];

    let input = req
        .get("input")
        .and_then(|v| v.as_str())
        .unwrap_or("")
        .to_string();

    println!("检测SQL注入: {}", input);

    let mut findings: Vec<&str> = HIGH_RISK_SEQUENCES
        .iter()
        .copied()
        .filter(|seq| input.contains(*seq))
        .collect();

    // Split on every character that cannot appear in an unquoted SQL identifier. This way
    // "UNION/**/SELECT" and "1 OR(1=1)" still yield the bare keywords.
    for word in input.split(|c: char| !(c.is_alphanumeric() || c == '_')) {
        for keyword in HIGH_RISK_KEYWORDS
            .iter()
            .chain(MEDIUM_RISK_KEYWORDS.iter())
            .copied()
        {
            if word.eq_ignore_ascii_case(keyword) && !findings.contains(&keyword) {
                findings.push(keyword);
            }
        }
    }

    let is_high = findings
        .iter()
        .any(|f| HIGH_RISK_SEQUENCES.contains(f) || HIGH_RISK_KEYWORDS.contains(f));
    let risk_level = if is_high {
        "HIGH"
    } else if !findings.is_empty() {
        // Only medium-risk keywords can remain at this point.
        "MEDIUM"
    } else {
        "LOW"
    };

    Json(ApiResponse {
        success: true,
        data: Some(serde_json::json!({
            "input": input,
            "risk_level": risk_level,
            "findings": findings
        })),
        message: None,
    })
}

/// `POST /api/security/build-safe-sql`: builds a single `field <operator> 'value'` condition.
///
/// The request body holds optional strings `table`, `field`, `operator` (default `"="`) and
/// `value`.
///
/// - `field` is interpolated verbatim, so it must be an identifier, optionally qualified
///   (`col` or `t.col`). Each dot-separated segment starts with a letter or `_` and continues
///   with letters, digits or `_`. Anything else is rejected with `success: false`.
/// - `operator` must be one of `=`, `!=`, `<>`, `<`, `<=`, `>`, `>=`, `LIKE`, `NOT LIKE`.
///   The check is case-insensitive and the operator is emitted upper-case. Anything else is
///   rejected.
/// - `value` is emitted as a single-quoted literal, with each embedded `'` doubled.
///
/// `table` is only echoed back; it is not part of the generated SQL.
///
/// NOTE(review): doubling quotes is not sufficient for MySQL in its default mode, where a
/// backslash also escapes a quote (`\'`). Dialect-aware escaping or bound parameters are
/// needed there.
async fn build_safe_sql_handler(
    State(_state): State<Arc<AppState>>,
    Json(req): Json<serde_json::Value>,
) -> Json<ApiResponse<serde_json::Value>> {
    // True for `name` or `qualifier.name`, where each segment is a conventional unquoted
    // identifier. The empty string fails because its single segment has no first character.
    fn is_safe_identifier(ident: &str) -> bool {
        ident.split('.').all(|segment| {
            let mut chars = segment.chars();
            matches!(chars.next(), Some(c) if c.is_alphabetic() || c == '_')
                && chars.all(|c| c.is_alphanumeric() || c == '_')
        })
    }

    // Builds the failure envelope used for invalid input.
    fn rejected(message: String) -> Json<ApiResponse<serde_json::Value>> {
        Json(ApiResponse {
            success: false,
            data: None,
            message: Some(message),
        })
    }

    let table = req.get("table").and_then(|v| v.as_str()).unwrap_or("");
    let field = req.get("field").and_then(|v| v.as_str()).unwrap_or("");
    let operator = req.get("operator").and_then(|v| v.as_str()).unwrap_or("=");
    let value = req.get("value").and_then(|v| v.as_str()).unwrap_or("");

    println!("构建安全SQL: {}.{} {} {}", table, field, operator, value);

    if !is_safe_identifier(field) {
        return rejected(format!("字段名不合法: {}", field));
    }

    let op = operator.trim().to_ascii_uppercase();
    if !matches!(
        op.as_str(),
        "=" | "!=" | "<>" | "<" | "<=" | ">" | ">=" | "LIKE" | "NOT LIKE"
    ) {
        return rejected(format!("不支持的运算符: {}", operator));
    }

    let safe_sql = format!("{} {} '{}'", field, op, value.replace('\'', "''"));

    Json(ApiResponse {
        success: true,
        data: Some(serde_json::json!({
            "safe_sql": safe_sql,
            "table": table,
            "field": field
        })),
        message: None,
    })
}