1use crate::container::DockerManager;
2use crate::sql_diff::TableDefinition;
3use anyhow::{Context, Result, anyhow};
4use docker_compose_types as dct;
5use mysql_async::prelude::*;
6use mysql_async::{Opts, Pool, Row, Transaction, TxOpts, from_row};
7
8pub struct MySqlExecutor {
11 pool: Pool,
12 config: MySqlConfig,
13}
14
/// Connection settings for a MySQL server.
///
/// Note: the derived `Debug` output includes `password`; avoid logging values of
/// this type verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MySqlConfig {
    /// Host name or IP address of the server.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// Account name used to authenticate.
    pub user: String,
    /// Password for `user`.
    pub password: String,
    /// Database (schema) name placed in the connection URL.
    pub database: String,
}
24
25impl MySqlConfig {
26 pub async fn for_container(compose_file: Option<&str>, env_file: Option<&str>) -> Result<Self> {
28 let docker_manager = match (compose_file, env_file) {
29 (Some(c), Some(e)) => DockerManager::with_project(c, e, None)?,
30 _ => {
31 return Err(anyhow!(
32 "docker-compose.yml and .env paths are required to load Docker Compose configuration"
33 ));
34 }
35 };
36 let compose_config = docker_manager
37 .load_compose_config()
38 .context("Failed to load Docker Compose configuration")?;
39
40 let mysql_service = compose_config
41 .services
42 .0
43 .get("mysql")
44 .and_then(|s| s.as_ref())
45 .ok_or_else(|| anyhow!("'mysql' service not found in docker-compose.yml"))?;
46
47 let mut config_map = std::collections::HashMap::new();
48 if let dct::Environment::List(env_list) = &mysql_service.environment {
49 for item in env_list {
50 if let Some((key, value)) = item.split_once('=') {
51 config_map.insert(key.to_string(), value.to_string());
52 }
53 }
54 }
55
56 let port = match &mysql_service.ports {
57 dct::Ports::Short(ports_list) => ports_list
58 .iter()
59 .find_map(|p| {
60 let parts: Vec<&str> = p.split(':').collect();
61 if parts.len() == 2 && parts[1] == "3306" {
62 parts[0].parse::<u16>().ok()
63 } else {
64 None
65 }
66 })
67 .ok_or_else(|| {
68 anyhow!("No mapping to container port 3306 found in 'mysql' service")
69 })?,
70 dct::Ports::Long(ports_list) => ports_list
71 .iter()
72 .find_map(|p| {
73 if p.target == 3306 {
74 match &p.published {
75 Some(dct::PublishedPort::Single(port_num)) => Some(*port_num),
76 Some(dct::PublishedPort::Range(port_str)) => {
77 port_str.parse::<u16>().ok()
78 }
79 None => None,
80 }
81 } else {
82 None
83 }
84 })
85 .ok_or_else(|| {
86 anyhow!("No mapping to container port 3306 found in 'mysql' service")
87 })?,
88 };
89
90 Ok(MySqlConfig {
91 host: "127.0.0.1".to_string(),
92 port,
93 user: config_map
94 .get("MYSQL_USER")
95 .cloned()
96 .unwrap_or_else(|| "root".to_string()),
97 password: config_map
98 .get("MYSQL_PASSWORD")
99 .cloned()
100 .unwrap_or_else(|| "root".to_string()),
101 database: config_map
102 .get("MYSQL_DATABASE")
103 .cloned()
104 .unwrap_or_else(|| "agent_platform".to_string()),
105 })
106 }
107
108 fn to_url(&self) -> String {
110 format!(
111 "mysql://{}:{}@{}:{}/{}",
112 self.user, self.password, self.host, self.port, self.database
113 )
114 }
115}
116
117impl MySqlExecutor {
118 pub fn new(config: MySqlConfig) -> Self {
120 let opts = Opts::from_url(&config.to_url()).unwrap();
121 let pool = Pool::new(opts);
122 Self { pool, config }
123 }
124
125 pub async fn test_connection(&self) -> Result<(), mysql_async::Error> {
127 let mut conn = self.pool.get_conn().await?;
128 conn.query_drop("SELECT 1").await?;
129 Ok(())
130 }
131
132 pub async fn execute_single(&self, sql: &str) -> Result<u64, mysql_async::Error> {
134 let mut conn = self.pool.get_conn().await?;
135 let result = conn.query_iter(sql).await?;
136 Ok(result.affected_rows())
137 }
138
139 pub async fn execute_diff_sql(&self, sql_content: &str) -> Result<Vec<String>, anyhow::Error> {
142 self.execute_diff_sql_with_retry(sql_content, 1).await
143 }
144
145 pub async fn execute_diff_sql_with_retry(
147 &self,
148 sql_content: &str,
149 max_retries: u8,
150 ) -> Result<Vec<String>, anyhow::Error> {
151 let sql_lines = self.parse_sql_commands(sql_content);
152 let mut results = Vec::new();
153 let mut last_error: Option<mysql_async::Error> = None;
154
155 for attempt in 0..=max_retries {
156 if attempt > 0 {
157 tokio::time::sleep(std::time::Duration::from_millis(500 * attempt as u64)).await;
158 results.push(format!("🔄 Retrying attempt {attempt}/{max_retries}..."));
159 }
160
161 let mut conn = self.pool.get_conn().await?;
162 let mut tx = conn.start_transaction(TxOpts::default()).await?;
163
164 let results_len_before_attempt = results.len();
166
167 match self
168 .execute_in_transaction(&mut tx, &sql_lines, &mut results)
169 .await
170 {
171 Ok(_) => {
172 tx.commit().await?;
173 results.insert(0, "✅ Diff SQL executed successfully".to_string());
174 return Ok(results);
175 }
176 Err(e) => {
177 tx.rollback().await?;
178 results.truncate(results_len_before_attempt);
180 results.push(format!("❌ Attempt {} failed: {}", attempt + 1, e));
181 last_error = Some(e);
182 }
183 }
184 }
185
186 Err(anyhow::anyhow!(
187 "❌ SQL execution failed after {} attempts. Last error: {}",
188 max_retries + 1,
189 last_error.unwrap()
190 ))
191 }
192
193 async fn execute_in_transaction<'a>(
195 &self,
196 tx: &mut Transaction<'a>,
197 lines: &[String],
198 results: &mut Vec<String>,
199 ) -> Result<(), mysql_async::Error> {
200 for (idx, sql) in lines.iter().enumerate() {
201 if sql.starts_with("--") || sql.trim().is_empty() {
202 continue;
203 }
204
205 tx.query_drop(sql).await?;
206 results.push(format!("[{}] ✅ {}", idx + 1, sql));
207 }
208 Ok(())
209 }
210
211 fn parse_sql_commands(&self, sql_content: &str) -> Vec<String> {
213 let mut commands = Vec::new();
214 let mut current_command = String::new();
215
216 for line in sql_content.lines() {
217 let line = line.trim();
218
219 if line.starts_with("--") || line.is_empty() {
220 continue;
221 }
222
223 current_command.push_str(line);
224 current_command.push(' ');
225
226 if line.ends_with(';') || line.ends_with("ENGINE=InnoDB;") || line.ends_with(");") {
228 commands.push(current_command.trim().to_string());
229 current_command.clear();
230 }
231 }
232
233 if !current_command.trim().is_empty() {
234 commands.push(current_command.trim().to_string());
235 }
236
237 commands
238 }
239
240 pub async fn get_table_info(&self, table_name: &str) -> Result<(), mysql_async::Error> {
242 let mut conn = self.pool.get_conn().await?;
243 let results: Vec<Row> = conn.query(format!("DESCRIBE {table_name}")).await?;
244
245 for row in results {
246 println!("{row:?}");
247 }
248 Ok(())
249 }
250
251 pub async fn fetch_live_schema(
253 &self,
254 ) -> Result<std::collections::HashMap<String, TableDefinition>, anyhow::Error> {
255 let (tables, _sql) = self.fetch_live_schema_with_sql().await?;
256 Ok(tables)
257 }
258
259 pub async fn fetch_live_schema_with_sql(
262 &self,
263 ) -> Result<(std::collections::HashMap<String, TableDefinition>, String), anyhow::Error> {
264 use crate::sql_diff::parse_sql_tables;
265
266 let mut conn = self.pool.get_conn().await?;
267
268 let table_names: Vec<String> = conn
270 .exec(
271 r#"SELECT TABLE_NAME
272 FROM INFORMATION_SCHEMA.TABLES
273 WHERE TABLE_SCHEMA = ?
274 ORDER BY TABLE_NAME"#,
275 (self.config.database.clone(),),
276 )
277 .await?
278 .into_iter()
279 .map(|row| {
280 let (name,): (String,) = from_row(row);
281 name
282 })
283 .collect();
284
285 let mut create_sqls = String::new();
287 for table in &table_names {
288 let query = format!("SHOW CREATE TABLE `{}`", table);
289 let row: Row = conn.exec_first(query, ()).await?.ok_or_else(|| {
290 anyhow::anyhow!(format!(
291 "Failed to get CREATE statement for table: {}",
292 table
293 ))
294 })?;
295 let (_tbl_name, create_stmt): (String, String) = from_row(row);
297 create_sqls.push_str(&create_stmt);
298
299 if !create_stmt.trim().ends_with(';') {
301 create_sqls.push(';');
302 }
303 create_sqls.push_str("\n\n");
304 }
305
306 let tables = parse_sql_tables(&create_sqls)
308 .map_err(|e| anyhow::anyhow!(format!("Failed to parse online DDL: {}", e)))?;
309
310 Ok((tables, create_sqls))
311 }
312
313 pub async fn verify_execution(
315 &self,
316 _expected_changes: &str,
317 ) -> Result<bool, mysql_async::Error> {
318 let mut conn = self.pool.get_conn().await?;
319
320 let result: Option<(i32,)> = conn.query_first("SELECT 1 as verification_status").await?;
322 if let Some((1,)) = result {
323 Ok(true)
324 } else {
325 Ok(false)
326 }
327 }
328
329 pub async fn health_check(&self) -> HealthStatus {
331 match self.test_connection().await {
332 Ok(_) => HealthStatus::Healthy,
333 Err(e) => HealthStatus::Failed(e.to_string()),
334 }
335 }
336}
337
/// Result of an executor health check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthStatus {
    /// A connection was obtained and `SELECT 1` succeeded.
    Healthy,
    /// The check failed; carries the error's display text.
    Failed(String),
}
343}
344
/// Outcome record for a single executed SQL statement.
///
/// NOTE(review): not constructed in this file; the field meanings below are
/// inferred from their names and should be confirmed against callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionResult {
    /// The statement text.
    pub sql: String,
    /// Presumably `true` when the statement succeeded.
    pub status: bool,
    /// Affected-row count, when available.
    pub rows_affected: Option<u64>,
    /// Error text, when the statement failed.
    pub error: Option<String>,
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the connection settings from the Compose fixtures under `fixtures/`.
    async fn fixture_config() -> MySqlConfig {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let fixtures_dir = std::path::Path::new(&manifest_dir).join("fixtures");
        let compose_path = fixtures_dir.join("docker-compose.yml");
        let env_path = fixtures_dir.join(".env");
        MySqlConfig::for_container(
            Some(compose_path.to_str().unwrap()),
            Some(env_path.to_str().unwrap()),
        )
        .await
        .unwrap()
    }

    /// Builds an executor from an in-memory config, so tests of pure helpers such
    /// as `parse_sql_commands` need neither the fixture files nor a server.
    /// `parse_sql_commands` never touches the pool.
    fn offline_executor() -> MySqlExecutor {
        MySqlExecutor::new(MySqlConfig {
            host: "127.0.0.1".to_string(),
            port: 3306,
            user: "root".to_string(),
            password: "root".to_string(),
            database: "offline".to_string(),
        })
    }

    #[tokio::test]
    async fn test_mysql_connection() {
        let config = fixture_config().await;
        let executor = MySqlExecutor::new(config.clone());
        if executor.test_connection().await.is_err() {
            println!("⚠️ MySQL容器未运行,跳过测试");
            return;
        }

        // Start from a clean database so an earlier, aborted run cannot make
        // `ALTER TABLE ... ADD COLUMN` fail on a column that already exists.
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
        executor
            .execute_single("CREATE DATABASE IF NOT EXISTS test_db")
            .await
            .unwrap();

        // `USE` only affects the session it runs in, and each executor call checks
        // out a pooled connection. Talk to test_db through a dedicated executor.
        let test_db = MySqlExecutor::new(MySqlConfig {
            database: "test_db".to_string(),
            ..config
        });

        test_db
            .execute_single(
                "CREATE TABLE IF NOT EXISTS test_table (id INT PRIMARY KEY, name VARCHAR(255))",
            )
            .await
            .unwrap();

        let results = test_db
            .execute_diff_sql("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(100)); \
                ALTER TABLE users ADD COLUMN email VARCHAR(255); \
                CREATE INDEX idx_name ON users(name);")
            .await
            .unwrap();

        assert!(!results.is_empty());
        println!("✅ MySQL执行器测试通过");

        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_parse_sql_commands() {
        let content = "-- 注释\n\
                       CREATE TABLE users (id INT);\n\
                       ALTER TABLE users ADD COLUMN name VARCHAR(100);\n\
                       CREATE INDEX idx_name ON users(name);";

        let executor = offline_executor();

        let commands = executor.parse_sql_commands(content);
        assert_eq!(commands.len(), 3);
        assert!(commands[0].contains("CREATE TABLE users"));
        assert!(commands[1].contains("ALTER TABLE users ADD COLUMN name"));
        assert!(commands[2].contains("CREATE INDEX idx_name"));
    }

    #[tokio::test]
    async fn test_empty_and_comments() {
        let content = "-- This is a comment\n\nCREATE TABLE test (id INT);\n-- Another comment";
        let executor = offline_executor();

        let commands = executor.parse_sql_commands(content);
        assert_eq!(commands.len(), 1);
        assert_eq!(commands[0], "CREATE TABLE test (id INT);");
    }

    #[test]
    fn test_table_name_normalization() {
        use crate::sql_diff::parse_sql_tables;

        let sql_with_backticks = "CREATE TABLE `test_table` (\n `id` int NOT NULL AUTO_INCREMENT,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";

        let sql_without_backticks = "CREATE TABLE test_table (\n id int NOT NULL AUTO_INCREMENT,\n PRIMARY KEY (id)\n) ENGINE=InnoDB;";

        let tables1 = parse_sql_tables(sql_with_backticks).expect("解析带反引号的 SQL 失败");
        let tables2 = parse_sql_tables(sql_without_backticks).expect("解析不带反引号的 SQL 失败");

        assert!(
            tables1.contains_key("test_table"),
            "带反引号的表名应该被标准化为 test_table"
        );
        assert!(
            tables2.contains_key("test_table"),
            "不带反引号的表名应该是 test_table"
        );

        assert!(
            !tables1.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );
        assert!(
            !tables2.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );

        println!("✅ 表名标准化测试通过");
    }

    #[test]
    fn test_sql_diff_with_same_tables() {
        use crate::sql_diff::{generate_schema_diff, parse_sql_tables};

        let mysql_sql = "CREATE TABLE `custom_page_config` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `name` varchar(255) NOT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";

        let file_sql = "CREATE TABLE custom_page_config (\n id bigint NOT NULL AUTO_INCREMENT,\n name varchar(255) NOT NULL,\n PRIMARY KEY (id)\n) ENGINE=InnoDB;";

        let mysql_tables = parse_sql_tables(mysql_sql).expect("解析 MySQL SQL 失败");
        let file_tables = parse_sql_tables(file_sql).expect("解析文件 SQL 失败");

        println!("MySQL 表: {:?}", mysql_tables.keys().collect::<Vec<_>>());
        println!("文件表: {:?}", file_tables.keys().collect::<Vec<_>>());

        let (diff_sql, description) =
            generate_schema_diff(Some(mysql_sql), file_sql, Some("在线架构"), "目标版本")
                .expect("生成差异 SQL 失败");

        println!("差异描述: {}", description);
        println!("差异 SQL:\n{}", diff_sql);

        let meaningful_lines: Vec<&str> = diff_sql
            .lines()
            .filter(|line| !line.trim().is_empty() && !line.trim().starts_with("--"))
            .collect();

        assert!(
            meaningful_lines.is_empty(),
            "相同的表不应该产生差异 SQL,但生成了: {:?}",
            meaningful_lines
        );

        println!("✅ SQL diff 测试通过:相同的表没有产生差异");
    }

    #[test]
    fn test_create_table_concatenation_with_semicolons() {
        let mut create_sqls = String::new();

        let stmt1 = "CREATE TABLE `agent_config` (\n `id` int NOT NULL AUTO_INCREMENT,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt1);

        if !stmt1.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");

        let stmt2 = "CREATE TABLE `agent_component_config` (\n `id` int NOT NULL AUTO_INCREMENT,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt2);

        if !stmt2.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");

        println!("拼接后的 SQL:\n{}", create_sqls);

        assert!(
            create_sqls.contains("ENGINE=InnoDB;"),
            "第一个表的语句应该以分号结尾"
        );
        assert!(
            create_sqls.matches("ENGINE=InnoDB;").count() == 2,
            "两个表的语句都应该以分号结尾"
        );

        use crate::sql_diff::parse_sql_tables;
        let result = parse_sql_tables(&create_sqls);

        if let Err(ref e) = result {
            println!("解析错误: {}", e);
        }

        assert!(
            result.is_ok(),
            "拼接后的 SQL 应该可以被正确解析: {:?}",
            result.err()
        );

        let tables = result.unwrap();
        println!("解析出的表: {:?}", tables.keys().collect::<Vec<_>>());
        assert_eq!(
            tables.len(),
            2,
            "应该解析出 2 个表,实际解析出 {} 个",
            tables.len()
        );

        let has_agent_config =
            tables.contains_key("agent_config") || tables.contains_key("`agent_config`");
        let has_agent_component_config = tables.contains_key("agent_component_config")
            || tables.contains_key("`agent_component_config`");

        assert!(has_agent_config, "应该包含 agent_config 表");
        assert!(
            has_agent_component_config,
            "应该包含 agent_component_config 表"
        );

        println!("✅ CREATE TABLE 语句拼接测试通过");
    }
}