use crate::container::DockerManager;
use crate::sql_diff::{TableColumn, TableDefinition, TableIndex};
use anyhow::{Context, Result, anyhow};
use docker_compose_types as dct;
use mysql_async::prelude::*;
use mysql_async::{Opts, Pool, Row, Transaction, TxOpts, from_row};
/// Runs SQL against a MySQL server through a `mysql_async` connection pool.
pub struct MySqlExecutor {
    /// Settings the pool was built from; `database` is also the schema name used
    /// when introspecting the live schema.
    config: MySqlConfig,
    /// Shared connection pool every query checks a connection out of.
    pool: Pool,
}
/// Connection settings for a MySQL server.
///
/// `Debug` is implemented by hand (instead of derived) so that the password is
/// redacted whenever the config is logged or printed with `{:?}`.
#[derive(Clone)]
pub struct MySqlConfig {
    /// Hostname or IP address of the server.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// Account used to authenticate.
    pub user: String,
    /// Password for `user`; never shown by the `Debug` impl.
    pub password: String,
    /// Database (schema) to connect to.
    pub database: String,
}

impl std::fmt::Debug for MySqlConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MySqlConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .finish()
    }
}
impl MySqlConfig {
    /// Builds a connection config for the `mysql` service of a Docker Compose project.
    ///
    /// The service is read from the compose file loaded through [`DockerManager`]:
    /// - `port`: the host port published for container port 3306 — short syntax
    ///   `HOST:3306` or `IP:HOST:3306` (optionally suffixed with `/tcp`), or long
    ///   syntax with `target: 3306`;
    /// - `user`: `MYSQL_USER`, defaulting to `root`;
    /// - `password`: `MYSQL_PASSWORD` when `MYSQL_USER` is set; otherwise (root)
    ///   `MYSQL_ROOT_PASSWORD`, falling back to `MYSQL_PASSWORD`; defaulting to `root`;
    /// - `database`: `MYSQL_DATABASE`, defaulting to `agent_platform`.
    ///
    /// The host is always `127.0.0.1`: the published port is assumed to be reachable
    /// on the local machine (any bind IP in `IP:HOST:3306` is not used).
    ///
    /// # Errors
    ///
    /// Fails when either path is `None`, the project or compose file cannot be loaded,
    /// the compose file has no `mysql` service, or no published mapping to container
    /// port 3306 can be found.
    pub async fn for_container(compose_file: Option<&str>, env_file: Option<&str>) -> Result<Self> {
        let docker_manager = match (compose_file, env_file) {
            (Some(compose), Some(env)) => DockerManager::with_project(compose, env, None)?,
            _ => {
                return Err(anyhow!(
                    "docker-compose.yml and .env paths are required to load Docker Compose configuration"
                ));
            }
        };
        let compose_config = docker_manager
            .load_compose_config()
            .context("Failed to load Docker Compose configuration")?;
        let mysql_service = compose_config
            .services
            .0
            .get("mysql")
            .and_then(|s| s.as_ref())
            .ok_or_else(|| anyhow!("'mysql' service not found in docker-compose.yml"))?;

        // Only the list form (`- KEY=value`) of `environment:` is read; other
        // `dct::Environment` variants (presumably the map form) fall through to the
        // defaults below. Entries without `=` are skipped.
        let mut config_map = std::collections::HashMap::new();
        if let dct::Environment::List(env_list) = &mysql_service.environment {
            for item in env_list {
                if let Some((key, value)) = item.split_once('=') {
                    config_map.insert(key.to_string(), value.to_string());
                }
            }
        }

        let port = match &mysql_service.ports {
            dct::Ports::Short(ports_list) => ports_list
                .iter()
                .find_map(|spec| Self::host_port_for_mysql(spec))
                .ok_or_else(|| anyhow!("No mapping to container port 3306 found in 'mysql' service"))?,
            dct::Ports::Long(ports_list) => ports_list
                .iter()
                .find_map(|p| {
                    if p.target != 3306 {
                        return None;
                    }
                    match &p.published {
                        Some(dct::PublishedPort::Single(port_num)) => Some(*port_num),
                        Some(dct::PublishedPort::Range(port_str)) => port_str.parse::<u16>().ok(),
                        None => None,
                    }
                })
                .ok_or_else(|| anyhow!("No mapping to container port 3306 found in 'mysql' service"))?,
            _ => return Err(anyhow!("Unsupported ports format or undefined ports in 'mysql' service")),
        };

        let user = config_map.get("MYSQL_USER").cloned();
        // MYSQL_PASSWORD belongs to MYSQL_USER; when no dedicated user is configured we
        // connect as root, whose password is MYSQL_ROOT_PASSWORD. Falling back to
        // MYSQL_PASSWORD keeps setups that relied on the previous lookup working.
        let password_entry = if user.is_some() {
            config_map.get("MYSQL_PASSWORD")
        } else {
            config_map
                .get("MYSQL_ROOT_PASSWORD")
                .or_else(|| config_map.get("MYSQL_PASSWORD"))
        };
        let password = password_entry
            .cloned()
            .unwrap_or_else(|| "root".to_string());

        Ok(MySqlConfig {
            host: "127.0.0.1".to_string(),
            port,
            user: user.unwrap_or_else(|| "root".to_string()),
            password,
            database: config_map
                .get("MYSQL_DATABASE")
                .cloned()
                .unwrap_or_else(|| "agent_platform".to_string()),
        })
    }

    /// Returns the host port of a short-syntax port mapping whose container port is 3306.
    ///
    /// Accepts `HOST:CONTAINER` and `IP:HOST:CONTAINER`, with an optional `/tcp`
    /// suffix. Returns `None` for other container ports, non-TCP protocols, specs
    /// without a host port, or host ports that are not a single number (e.g. ranges).
    fn host_port_for_mysql(spec: &str) -> Option<u16> {
        let (mapping, protocol) = match spec.split_once('/') {
            Some((mapping, protocol)) => (mapping, Some(protocol)),
            None => (spec, None),
        };
        if matches!(protocol, Some(p) if !p.eq_ignore_ascii_case("tcp")) {
            return None;
        }
        // Split from the right so an optional leading bind IP (even IPv6) is ignored.
        let mut parts = mapping.rsplit(':');
        let container = parts.next()?;
        let host = parts.next()?;
        if container.trim() == "3306" {
            host.trim().parse::<u16>().ok()
        } else {
            None
        }
    }

    /// Renders the config as a `mysql://user:pass@host:port/database` URL for
    /// `Opts::from_url`.
    ///
    /// User, password and database are percent-encoded, so URL-reserved characters
    /// in credentials (`@`, `:`, `/`, `#`, `%`, ...) cannot corrupt the URL.
    fn to_url(&self) -> String {
        format!(
            "mysql://{}:{}@{}:{}/{}",
            Self::percent_encode(&self.user),
            Self::percent_encode(&self.password),
            self.host,
            self.port,
            Self::percent_encode(&self.database)
        )
    }

    /// Percent-encodes every byte outside the RFC 3986 unreserved set
    /// (`A-Z a-z 0-9 - . _ ~`).
    fn percent_encode(raw: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        out
    }
}
impl MySqlExecutor {
pub fn new(config: MySqlConfig) -> Self {
let opts = Opts::from_url(&config.to_url()).unwrap();
let pool = Pool::new(opts);
Self { pool, config }
}
pub async fn test_connection(&self) -> Result<(), mysql_async::Error> {
let mut conn = self.pool.get_conn().await?;
conn.query_drop("SELECT 1").await?;
Ok(())
}
pub async fn execute_single(&self, sql: &str) -> Result<u64, mysql_async::Error> {
let mut conn = self.pool.get_conn().await?;
let result = conn.query_iter(sql).await?;
Ok(result.affected_rows())
}
pub async fn execute_diff_sql(&self, sql_content: &str) -> Result<Vec<String>, anyhow::Error> {
self.execute_diff_sql_with_retry(sql_content, 1).await
}
pub async fn execute_diff_sql_with_retry(
&self,
sql_content: &str,
max_retries: u8,
) -> Result<Vec<String>, anyhow::Error> {
let sql_lines = self.parse_sql_commands(sql_content);
let mut results = Vec::new();
let mut last_error: Option<mysql_async::Error> = None;
for attempt in 0..=max_retries {
if attempt > 0 {
tokio::time::sleep(std::time::Duration::from_millis(500 * attempt as u64)).await;
results.push(format!("🔄 Retrying attempt {attempt}/{max_retries}..."));
}
let mut conn = self.pool.get_conn().await?;
let mut tx = conn.start_transaction(TxOpts::default()).await?;
let results_len_before_attempt = results.len();
match self
.execute_in_transaction(&mut tx, &sql_lines, &mut results)
.await
{
Ok(_) => {
tx.commit().await?;
results.insert(0, "✅ Diff SQL executed successfully".to_string());
return Ok(results);
}
Err(e) => {
tx.rollback().await?;
results.truncate(results_len_before_attempt);
results.push(format!("❌ Attempt {} failed: {}", attempt + 1, e));
last_error = Some(e);
}
}
}
Err(anyhow::anyhow!(
"❌ SQL execution failed after {} attempts. Last error: {}",
max_retries + 1,
last_error.unwrap()
))
}
async fn execute_in_transaction<'a>(
&self,
tx: &mut Transaction<'a>,
lines: &[String],
results: &mut Vec<String>,
) -> Result<(), mysql_async::Error> {
for (idx, sql) in lines.iter().enumerate() {
if sql.starts_with("--") || sql.trim().is_empty() {
continue;
}
tx.query_drop(sql).await?;
results.push(format!("[{}] ✅ {}", idx + 1, sql));
}
Ok(())
}
fn parse_sql_commands(&self, sql_content: &str) -> Vec<String> {
let mut commands = Vec::new();
let mut current_command = String::new();
for line in sql_content.lines() {
let line = line.trim();
if line.starts_with("--") || line.is_empty() {
continue;
}
current_command.push_str(line);
current_command.push(' ');
if line.ends_with(';') || line.ends_with("ENGINE=InnoDB;") || line.ends_with(");") {
commands.push(current_command.trim().to_string());
current_command.clear();
}
}
if !current_command.trim().is_empty() {
commands.push(current_command.trim().to_string());
}
commands
}
pub async fn get_table_info(&self, table_name: &str) -> Result<(), mysql_async::Error> {
let mut conn = self.pool.get_conn().await?;
let results: Vec<Row> = conn.query(format!("DESCRIBE {table_name}")).await?;
for row in results {
println!("{row:?}");
}
Ok(())
}
pub async fn fetch_live_schema(
&self,
) -> Result<std::collections::HashMap<String, TableDefinition>, anyhow::Error> {
let (tables, _sql) = self.fetch_live_schema_with_sql().await?;
Ok(tables)
}
pub async fn fetch_live_schema_with_sql(
&self,
) -> Result<(std::collections::HashMap<String, TableDefinition>, String), anyhow::Error> {
use crate::sql_diff::parse_sql_tables;
let mut conn = self.pool.get_conn().await?;
let table_names: Vec<String> = conn
.exec(
r#"SELECT TABLE_NAME
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = ?
ORDER BY TABLE_NAME"#,
(self.config.database.clone(),),
)
.await?
.into_iter()
.map(|row| {
let (name,): (String,) = from_row(row);
name
})
.collect();
let mut create_sqls = String::new();
for table in &table_names {
let query = format!("SHOW CREATE TABLE `{}`", table);
let row: Row = conn
.exec_first(query, ())
.await?
.ok_or_else(|| anyhow::anyhow!(format!("Failed to get CREATE statement for table: {}", table)))?;
let (_tbl_name, create_stmt): (String, String) = from_row(row);
create_sqls.push_str(&create_stmt);
if !create_stmt.trim().ends_with(';') {
create_sqls.push(';');
}
create_sqls.push_str("\n\n");
}
let tables = parse_sql_tables(&create_sqls)
.map_err(|e| anyhow::anyhow!(format!("Failed to parse online DDL: {}", e)))?;
Ok((tables, create_sqls))
}
pub async fn verify_execution(
&self,
_expected_changes: &str,
) -> Result<bool, mysql_async::Error> {
let mut conn = self.pool.get_conn().await?;
let result: Option<(i32,)> = conn.query_first("SELECT 1 as verification_status").await?;
if let Some((1,)) = result {
Ok(true)
} else {
Ok(false)
}
}
pub async fn health_check(&self) -> HealthStatus {
match self.test_connection().await {
Ok(_) => HealthStatus::Healthy,
Err(e) => HealthStatus::Failed(e.to_string()),
}
}
}
/// Outcome of a connectivity probe such as `MySqlExecutor::health_check`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthStatus {
    /// The probe succeeded.
    Healthy,
    /// The probe failed; carries the rendered error message.
    Failed(String),
}

impl HealthStatus {
    /// Returns `true` for [`HealthStatus::Healthy`].
    #[must_use]
    pub fn is_healthy(&self) -> bool {
        matches!(self, Self::Healthy)
    }
}
/// Outcome record for a single executed SQL statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionResult {
    /// The statement text.
    pub sql: String,
    /// Presumably `true` when the statement succeeded — TODO confirm with producers.
    pub status: bool,
    /// Affected-row count, when one was reported.
    pub rows_affected: Option<u64>,
    /// Error message, when the statement failed.
    pub error: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the connection config from the compose/.env fixtures shipped with the crate.
    async fn fixture_config() -> MySqlConfig {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let compose_path = std::path::Path::new(&manifest_dir).join("fixtures/docker-compose.yml");
        let env_path = std::path::Path::new(&manifest_dir).join("fixtures/.env");
        MySqlConfig::for_container(
            Some(compose_path.to_str().unwrap()),
            Some(env_path.to_str().unwrap()),
        )
        .await
        .unwrap()
    }

    /// Builds an executor for tests that only exercise SQL parsing. It needs no
    /// fixture files, and parsing never asks the pool for a connection.
    fn offline_executor() -> MySqlExecutor {
        MySqlExecutor::new(MySqlConfig {
            host: "127.0.0.1".to_string(),
            port: 3306,
            user: "root".to_string(),
            password: "root".to_string(),
            database: "agent_platform".to_string(),
        })
    }

    #[tokio::test]
    async fn test_mysql_connection() {
        let config = fixture_config().await;
        let executor = MySqlExecutor::new(config.clone());
        if executor.test_connection().await.is_err() {
            println!("⚠️ MySQL容器未运行,跳过测试");
            return;
        }
        // Start from a clean scratch database in case a previous run failed midway.
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
        executor
            .execute_single("CREATE DATABASE IF NOT EXISTS test_db")
            .await
            .unwrap();
        // `USE test_db` would only affect one pooled connection, so point a separate
        // executor's pool at the scratch database instead. This keeps the test tables
        // out of the configured database.
        let test_executor = MySqlExecutor::new(MySqlConfig {
            database: "test_db".to_string(),
            ..config
        });
        test_executor
            .execute_single(
                "CREATE TABLE IF NOT EXISTS test_table (id INT PRIMARY KEY, name VARCHAR(255))",
            )
            .await
            .unwrap();
        let results = test_executor
            .execute_diff_sql("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(100)); \
                              ALTER TABLE users ADD COLUMN email VARCHAR(255); \
                              CREATE INDEX idx_name ON users(name);")
            .await
            .unwrap();
        assert!(!results.is_empty());
        println!("✅ MySQL执行器测试通过");
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_parse_sql_commands() {
        let content = "-- 注释\n\
                       CREATE TABLE users (id INT);\n\
                       ALTER TABLE users ADD COLUMN name VARCHAR(100);\n\
                       CREATE INDEX idx_name ON users(name);";
        let executor = offline_executor();
        let commands = executor.parse_sql_commands(content);
        assert_eq!(commands.len(), 3);
        assert!(commands[0].contains("CREATE TABLE users"));
        assert!(commands[1].contains("ALTER TABLE users ADD COLUMN name"));
    }

    #[tokio::test]
    async fn test_empty_and_comments() {
        let content = "-- This is a comment\n\nCREATE TABLE test (id INT);\n-- Another comment";
        let executor = offline_executor();
        let commands = executor.parse_sql_commands(content);
        assert_eq!(commands.len(), 1);
        assert_eq!(commands[0], "CREATE TABLE test (id INT);");
    }

    #[test]
    fn test_table_name_normalization() {
        use crate::sql_diff::parse_sql_tables;
        let sql_with_backticks = "CREATE TABLE `test_table` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";
        let sql_without_backticks = "CREATE TABLE test_table (\n  id int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (id)\n) ENGINE=InnoDB;";
        let tables1 = parse_sql_tables(sql_with_backticks).expect("解析带反引号的 SQL 失败");
        let tables2 = parse_sql_tables(sql_without_backticks).expect("解析不带反引号的 SQL 失败");
        assert!(
            tables1.contains_key("test_table"),
            "带反引号的表名应该被标准化为 test_table"
        );
        assert!(
            tables2.contains_key("test_table"),
            "不带反引号的表名应该是 test_table"
        );
        assert!(
            !tables1.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );
        assert!(
            !tables2.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );
        println!("✅ 表名标准化测试通过");
    }

    #[test]
    fn test_sql_diff_with_same_tables() {
        use crate::sql_diff::{generate_schema_diff, parse_sql_tables};
        let mysql_sql = "CREATE TABLE `custom_page_config` (\n  `id` bigint NOT NULL AUTO_INCREMENT,\n  `name` varchar(255) NOT NULL,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";
        let file_sql = "CREATE TABLE custom_page_config (\n  id bigint NOT NULL AUTO_INCREMENT,\n  name varchar(255) NOT NULL,\n  PRIMARY KEY (id)\n) ENGINE=InnoDB;";
        let mysql_tables = parse_sql_tables(mysql_sql).expect("解析 MySQL SQL 失败");
        let file_tables = parse_sql_tables(file_sql).expect("解析文件 SQL 失败");
        println!("MySQL 表: {:?}", mysql_tables.keys().collect::<Vec<_>>());
        println!("文件表: {:?}", file_tables.keys().collect::<Vec<_>>());
        let (diff_sql, description) =
            generate_schema_diff(Some(mysql_sql), file_sql, Some("在线架构"), "目标版本")
                .expect("生成差异 SQL 失败");
        println!("差异描述: {}", description);
        println!("差异 SQL:\n{}", diff_sql);
        let meaningful_lines: Vec<&str> = diff_sql
            .lines()
            .filter(|line| !line.trim().is_empty() && !line.trim().starts_with("--"))
            .collect();
        assert!(
            meaningful_lines.is_empty(),
            "相同的表不应该产生差异 SQL,但生成了: {:?}",
            meaningful_lines
        );
        println!("✅ SQL diff 测试通过:相同的表没有产生差异");
    }

    #[test]
    fn test_create_table_concatenation_with_semicolons() {
        let mut create_sqls = String::new();
        let stmt1 = "CREATE TABLE `agent_config` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt1);
        if !stmt1.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");
        let stmt2 = "CREATE TABLE `agent_component_config` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt2);
        if !stmt2.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");
        println!("拼接后的 SQL:\n{}", create_sqls);
        assert!(
            create_sqls.contains("ENGINE=InnoDB;"),
            "第一个表的语句应该以分号结尾"
        );
        assert!(
            create_sqls.matches("ENGINE=InnoDB;").count() == 2,
            "两个表的语句都应该以分号结尾"
        );
        use crate::sql_diff::parse_sql_tables;
        let result = parse_sql_tables(&create_sqls);
        if let Err(ref e) = result {
            println!("解析错误: {}", e);
        }
        assert!(
            result.is_ok(),
            "拼接后的 SQL 应该可以被正确解析: {:?}",
            result.err()
        );
        let tables = result.unwrap();
        println!("解析出的表: {:?}", tables.keys().collect::<Vec<_>>());
        assert_eq!(
            tables.len(),
            2,
            "应该解析出 2 个表,实际解析出 {} 个",
            tables.len()
        );
        let has_agent_config =
            tables.contains_key("agent_config") || tables.contains_key("`agent_config`");
        let has_agent_component_config = tables.contains_key("agent_component_config")
            || tables.contains_key("`agent_component_config`");
        assert!(has_agent_config, "应该包含 agent_config 表");
        assert!(
            has_agent_component_config,
            "应该包含 agent_component_config 表"
        );
        println!("✅ CREATE TABLE 语句拼接测试通过");
    }
}