Skip to main content

client_core/
mysql_executor.rs

1use crate::container::DockerManager;
2use crate::sql_diff::{TableColumn, TableDefinition, TableIndex};
3use anyhow::{Context, Result, anyhow};
4use docker_compose_types as dct;
5use mysql_async::prelude::*;
6use mysql_async::{Opts, Pool, Row, Transaction, TxOpts, from_row};
7
8/// MySQL容器异步差异SQL执行器
9/// 专为Duck Client自动升级部署设计
10pub struct MySqlExecutor {
11    pool: Pool,
12    config: MySqlConfig,
13}
14
/// Connection settings for the MySQL instance used by the existing deployment.
///
/// `Debug` is implemented by hand so that the password is never written to
/// logs or panic messages; every other field is shown as-is.
#[derive(Clone)]
pub struct MySqlConfig {
    pub host: String,
    pub port: u16,
    pub user: String,
    pub password: String,
    pub database: String,
}

impl std::fmt::Debug for MySqlConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MySqlConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Redacted on purpose: config values end up in debug logs.
            .field("password", &"***")
            .field("database", &self.database)
            .finish()
    }
}
24
25impl MySqlConfig {
26    /// 通过解析 docker-compose.yml 文件为容器环境适配配置
27    pub async fn for_container(compose_file: Option<&str>, env_file: Option<&str>) -> Result<Self> {
28        let docker_manager = match (compose_file, env_file) {
29            (Some(c), Some(e)) => DockerManager::with_project(c, e, None)?,
30            _ => {
31                return Err(anyhow!(
32                    "docker-compose.yml and .env paths are required to load Docker Compose configuration"
33                ));
34            }
35        };
36        let compose_config = docker_manager
37            .load_compose_config()
38            .context("Failed to load Docker Compose configuration")?;
39
40        let mysql_service = compose_config
41            .services
42            .0
43            .get("mysql")
44            .and_then(|s| s.as_ref())
45            .ok_or_else(|| anyhow!("'mysql' service not found in docker-compose.yml"))?;
46
47        let mut config_map = std::collections::HashMap::new();
48        if let dct::Environment::List(env_list) = &mysql_service.environment {
49            for item in env_list {
50                if let Some((key, value)) = item.split_once('=') {
51                    config_map.insert(key.to_string(), value.to_string());
52                }
53            }
54        }
55
56        let port = match &mysql_service.ports {
57            dct::Ports::Short(ports_list) => ports_list
58                .iter()
59                .find_map(|p| {
60                    let parts: Vec<&str> = p.split(':').collect();
61                    if parts.len() == 2 && parts[1] == "3306" {
62                        parts[0].parse::<u16>().ok()
63                    } else {
64                        None
65                    }
66                })
67                .ok_or_else(|| anyhow!("No mapping to container port 3306 found in 'mysql' service"))?,
68            dct::Ports::Long(ports_list) => ports_list
69                .iter()
70                .find_map(|p| {
71                    if p.target == 3306 {
72                        match &p.published {
73                            Some(dct::PublishedPort::Single(port_num)) => Some(*port_num),
74                            Some(dct::PublishedPort::Range(port_str)) => {
75                                port_str.parse::<u16>().ok()
76                            }
77                            None => None,
78                        }
79                    } else {
80                        None
81                    }
82                })
83                .ok_or_else(|| anyhow!("No mapping to container port 3306 found in 'mysql' service"))?,
84            _ => return Err(anyhow!("Unsupported ports format or undefined ports in 'mysql' service")),
85        };
86
87        Ok(MySqlConfig {
88            host: "127.0.0.1".to_string(),
89            port,
90            user: config_map
91                .get("MYSQL_USER")
92                .cloned()
93                .unwrap_or_else(|| "root".to_string()),
94            password: config_map
95                .get("MYSQL_PASSWORD")
96                .cloned()
97                .unwrap_or_else(|| "root".to_string()),
98            database: config_map
99                .get("MYSQL_DATABASE")
100                .cloned()
101                .unwrap_or_else(|| "agent_platform".to_string()),
102        })
103    }
104
105    /// 生成连接URL
106    fn to_url(&self) -> String {
107        format!(
108            "mysql://{}:{}@{}:{}/{}",
109            self.user, self.password, self.host, self.port, self.database
110        )
111    }
112}
113
114impl MySqlExecutor {
115    /// 创建新的执行器
116    pub fn new(config: MySqlConfig) -> Self {
117        let opts = Opts::from_url(&config.to_url()).unwrap();
118        let pool = Pool::new(opts);
119        Self { pool, config }
120    }
121
122    /// 测试连接是否可用
123    pub async fn test_connection(&self) -> Result<(), mysql_async::Error> {
124        let mut conn = self.pool.get_conn().await?;
125        conn.query_drop("SELECT 1").await?;
126        Ok(())
127    }
128
129    /// 执行单个SQL语句
130    pub async fn execute_single(&self, sql: &str) -> Result<u64, mysql_async::Error> {
131        let mut conn = self.pool.get_conn().await?;
132        let result = conn.query_iter(sql).await?;
133        Ok(result.affected_rows())
134    }
135
136    /// 执行差异SQL内容(多语句支持)
137    /// 自动处理注释和空行,支持事务回滚
138    pub async fn execute_diff_sql(&self, sql_content: &str) -> Result<Vec<String>, anyhow::Error> {
139        self.execute_diff_sql_with_retry(sql_content, 1).await
140    }
141
142    /// 带重试机制的SQL执行
143    pub async fn execute_diff_sql_with_retry(
144        &self,
145        sql_content: &str,
146        max_retries: u8,
147    ) -> Result<Vec<String>, anyhow::Error> {
148        let sql_lines = self.parse_sql_commands(sql_content);
149        let mut results = Vec::new();
150        let mut last_error: Option<mysql_async::Error> = None;
151
152        for attempt in 0..=max_retries {
153            if attempt > 0 {
154                tokio::time::sleep(std::time::Duration::from_millis(500 * attempt as u64)).await;
155                results.push(format!("🔄 Retrying attempt {attempt}/{max_retries}..."));
156            }
157
158            let mut conn = self.pool.get_conn().await?;
159            let mut tx = conn.start_transaction(TxOpts::default()).await?;
160
161            // 记录本次尝试前的日志数量,如果失败可以回滚
162            let results_len_before_attempt = results.len();
163
164            match self
165                .execute_in_transaction(&mut tx, &sql_lines, &mut results)
166                .await
167            {
168                Ok(_) => {
169                    tx.commit().await?;
170                    results.insert(0, "✅ Diff SQL executed successfully".to_string());
171                    return Ok(results);
172                }
173                Err(e) => {
174                    tx.rollback().await?;
175                    // 移除本次失败尝试中添加的日志
176                    results.truncate(results_len_before_attempt);
177                    results.push(format!("❌ Attempt {} failed: {}", attempt + 1, e));
178                    last_error = Some(e);
179                }
180            }
181        }
182
183        Err(anyhow::anyhow!(
184            "❌ SQL execution failed after {} attempts. Last error: {}",
185            max_retries + 1,
186            last_error.unwrap()
187        ))
188    }
189
190    /// 执行在事务中的差异SQL
191    async fn execute_in_transaction<'a>(
192        &self,
193        tx: &mut Transaction<'a>,
194        lines: &[String],
195        results: &mut Vec<String>,
196    ) -> Result<(), mysql_async::Error> {
197        for (idx, sql) in lines.iter().enumerate() {
198            if sql.starts_with("--") || sql.trim().is_empty() {
199                continue;
200            }
201
202            tx.query_drop(sql).await?;
203            results.push(format!("[{}] ✅ {}", idx + 1, sql));
204        }
205        Ok(())
206    }
207
208    /// 解析SQL内容为可执行的命令列表
209    fn parse_sql_commands(&self, sql_content: &str) -> Vec<String> {
210        let mut commands = Vec::new();
211        let mut current_command = String::new();
212
213        for line in sql_content.lines() {
214            let line = line.trim();
215
216            if line.starts_with("--") || line.is_empty() {
217                continue;
218            }
219
220            current_command.push_str(line);
221            current_command.push(' ');
222
223            // 如果行的末尾是分号SQL结束
224            if line.ends_with(';') || line.ends_with("ENGINE=InnoDB;") || line.ends_with(");") {
225                commands.push(current_command.trim().to_string());
226                current_command.clear();
227            }
228        }
229
230        if !current_command.trim().is_empty() {
231            commands.push(current_command.trim().to_string());
232        }
233
234        commands
235    }
236
237    /// 获取数据库表结构信息
238    pub async fn get_table_info(&self, table_name: &str) -> Result<(), mysql_async::Error> {
239        let mut conn = self.pool.get_conn().await?;
240        let results: Vec<Row> = conn.query(format!("DESCRIBE {table_name}")).await?;
241
242        for row in results {
243            println!("{row:?}");
244        }
245        Ok(())
246    }
247
248    /// 抓取在线数据库架构:通过 SHOW CREATE TABLE 获取真实DDL,再用 sqlparser 解析为内部类型
249    pub async fn fetch_live_schema(
250        &self,
251    ) -> Result<std::collections::HashMap<String, TableDefinition>, anyhow::Error> {
252        let (tables, _sql) = self.fetch_live_schema_with_sql().await?;
253        Ok(tables)
254    }
255
256    /// 抓取在线数据库架构并返回原始 SQL
257    /// 返回:(解析后的表定义, 原始 CREATE TABLE SQL)
258    pub async fn fetch_live_schema_with_sql(
259        &self,
260    ) -> Result<(std::collections::HashMap<String, TableDefinition>, String), anyhow::Error> {
261        use crate::sql_diff::parse_sql_tables;
262
263        let mut conn = self.pool.get_conn().await?;
264
265        // 获取当前数据库所有表名
266        let table_names: Vec<String> = conn
267            .exec(
268                r#"SELECT TABLE_NAME
269                    FROM INFORMATION_SCHEMA.TABLES
270                    WHERE TABLE_SCHEMA = ?
271                    ORDER BY TABLE_NAME"#,
272                (self.config.database.clone(),),
273            )
274            .await?
275            .into_iter()
276            .map(|row| {
277                let (name,): (String,) = from_row(row);
278                name
279            })
280            .collect();
281
282        // 拼接所有表的 CREATE 语句
283        let mut create_sqls = String::new();
284        for table in &table_names {
285            let query = format!("SHOW CREATE TABLE `{}`", table);
286            let row: Row = conn
287                .exec_first(query, ())
288                .await?
289                .ok_or_else(|| anyhow::anyhow!(format!("Failed to get CREATE statement for table: {}", table)))?;
290            // MySQL返回两列:Table, Create Table
291            let (_tbl_name, create_stmt): (String, String) = from_row(row);
292            create_sqls.push_str(&create_stmt);
293
294            // 确保每个 CREATE TABLE 语句以分号结尾
295            if !create_stmt.trim().ends_with(';') {
296                create_sqls.push(';');
297            }
298            create_sqls.push_str("\n\n");
299        }
300
301        // 使用 sqlparser 解析 DDL,严格避免正则
302        let tables = parse_sql_tables(&create_sqls)
303            .map_err(|e| anyhow::anyhow!(format!("Failed to parse online DDL: {}", e)))?;
304
305        Ok((tables, create_sqls))
306    }
307
308    /// 验证执行结果
309    pub async fn verify_execution(
310        &self,
311        _expected_changes: &str,
312    ) -> Result<bool, mysql_async::Error> {
313        let mut conn = self.pool.get_conn().await?;
314
315        // 简单的执行确认
316        let result: Option<(i32,)> = conn.query_first("SELECT 1 as verification_status").await?;
317        if let Some((1,)) = result {
318            Ok(true)
319        } else {
320            Ok(false)
321        }
322    }
323
324    /// 检查数据库连接健康
325    pub async fn health_check(&self) -> HealthStatus {
326        match self.test_connection().await {
327            Ok(_) => HealthStatus::Healthy,
328            Err(e) => HealthStatus::Failed(e.to_string()),
329        }
330    }
331}
332
/// Result of a database health check.
///
/// `Failed` carries the error message of the failed connectivity check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthStatus {
    Healthy,
    Failed(String),
}

impl HealthStatus {
    /// Returns `true` for [`HealthStatus::Healthy`].
    pub fn is_healthy(&self) -> bool {
        matches!(self, Self::Healthy)
    }
}
339
/// Record of a single executed SQL statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionResult {
    /// The statement text that was executed.
    pub sql: String,
    /// Whether the statement succeeded.
    pub status: bool,
    /// Affected-row count, when one was reported.
    pub rows_affected: Option<u64>,
    /// Error message, when the statement failed.
    pub error: Option<String>,
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the connection config from the compose fixtures shipped with the crate
    /// (`fixtures/docker-compose.yml` and `fixtures/.env`).
    async fn fixture_config() -> MySqlConfig {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let fixtures = std::path::Path::new(&manifest_dir).join("fixtures");
        let compose_path = fixtures.join("docker-compose.yml");
        let env_path = fixtures.join(".env");
        MySqlConfig::for_container(
            Some(compose_path.to_str().unwrap()),
            Some(env_path.to_str().unwrap()),
        )
        .await
        .unwrap()
    }

    /// Executor for tests that only exercise the SQL parser.
    ///
    /// Its pool is never used, so neither the fixtures nor a running database are needed.
    fn offline_executor() -> MySqlExecutor {
        MySqlExecutor::new(MySqlConfig {
            host: "127.0.0.1".to_string(),
            port: 3306,
            user: "root".to_string(),
            password: "root".to_string(),
            database: "agent_platform".to_string(),
        })
    }

    #[tokio::test]
    async fn test_mysql_connection() {
        let config = fixture_config().await;
        let executor = MySqlExecutor::new(config.clone());
        if executor.test_connection().await.is_err() {
            println!("⚠️ MySQL容器未运行,跳过测试");
            return;
        }

        // Start from a clean scratch database so a rerun after a failed run still passes.
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
        executor
            .execute_single("CREATE DATABASE IF NOT EXISTS test_db")
            .await
            .unwrap();

        // `USE test_db` would only affect a single pooled connection. Bind a separate
        // executor to test_db so the tables below never land in the configured database.
        let test_executor = MySqlExecutor::new(MySqlConfig {
            database: "test_db".to_string(),
            ..config
        });

        test_executor
            .execute_single(
                "CREATE TABLE IF NOT EXISTS test_table (id INT PRIMARY KEY, name VARCHAR(255))",
            )
            .await
            .unwrap();

        let results = test_executor
            .execute_diff_sql("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(100)); \
                             ALTER TABLE users ADD COLUMN email VARCHAR(255); \
                             CREATE INDEX idx_name ON users(name);")
            .await
            .unwrap();

        assert!(!results.is_empty());
        println!("✅ MySQL执行器测试通过");

        // Clean up.
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_parse_sql_commands() {
        let content = "-- 注释\n\
                      CREATE TABLE users (id INT);\n\
                      ALTER TABLE users ADD COLUMN name VARCHAR(100);\n\
                      CREATE INDEX idx_name ON users(name);";

        let executor = offline_executor();

        let commands = executor.parse_sql_commands(content);
        assert_eq!(commands.len(), 3);
        assert!(commands[0].contains("CREATE TABLE users"));
        assert!(commands[1].contains("ALTER TABLE users ADD COLUMN name"));
    }

    #[tokio::test]
    async fn test_empty_and_comments() {
        let content = "-- This is a comment\n\nCREATE TABLE test (id INT);\n-- Another comment";
        let executor = offline_executor();

        let commands = executor.parse_sql_commands(content);
        assert_eq!(commands.len(), 1);
        assert_eq!(commands[0], "CREATE TABLE test (id INT);");
    }

    #[test]
    fn test_table_name_normalization() {
        // Backticked and plain table names must be recognised as the same table.
        use crate::sql_diff::parse_sql_tables;

        // SQL 1: backticked table name.
        let sql_with_backticks = "CREATE TABLE `test_table` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";

        // SQL 2: plain table name.
        let sql_without_backticks = "CREATE TABLE test_table (\n  id int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (id)\n) ENGINE=InnoDB;";

        let tables1 = parse_sql_tables(sql_with_backticks).expect("解析带反引号的 SQL 失败");
        let tables2 = parse_sql_tables(sql_without_backticks).expect("解析不带反引号的 SQL 失败");

        // Both must yield the same (unquoted) key.
        assert!(
            tables1.contains_key("test_table"),
            "带反引号的表名应该被标准化为 test_table"
        );
        assert!(
            tables2.contains_key("test_table"),
            "不带反引号的表名应该是 test_table"
        );

        // No backticked keys may appear.
        assert!(
            !tables1.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );
        assert!(
            !tables2.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );

        println!("✅ 表名标准化测试通过");
    }

    #[test]
    fn test_sql_diff_with_same_tables() {
        // A table read from MySQL (backticked) and the same table from a file (plain)
        // must produce no diff.
        use crate::sql_diff::{generate_schema_diff, parse_sql_tables};

        // Shaped like SHOW CREATE TABLE output (backticked identifiers).
        let mysql_sql = "CREATE TABLE `custom_page_config` (\n  `id` bigint NOT NULL AUTO_INCREMENT,\n  `name` varchar(255) NOT NULL,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";

        // Shaped like a schema file (plain identifiers).
        let file_sql = "CREATE TABLE custom_page_config (\n  id bigint NOT NULL AUTO_INCREMENT,\n  name varchar(255) NOT NULL,\n  PRIMARY KEY (id)\n) ENGINE=InnoDB;";

        let mysql_tables = parse_sql_tables(mysql_sql).expect("解析 MySQL SQL 失败");
        let file_tables = parse_sql_tables(file_sql).expect("解析文件 SQL 失败");

        println!("MySQL 表: {:?}", mysql_tables.keys().collect::<Vec<_>>());
        println!("文件表: {:?}", file_tables.keys().collect::<Vec<_>>());

        let (diff_sql, description) =
            generate_schema_diff(Some(mysql_sql), file_sql, Some("在线架构"), "目标版本")
                .expect("生成差异 SQL 失败");

        println!("差异描述: {}", description);
        println!("差异 SQL:\n{}", diff_sql);

        // Identical structures: only blank lines or comments may remain.
        let meaningful_lines: Vec<&str> = diff_sql
            .lines()
            .filter(|line| !line.trim().is_empty() && !line.trim().starts_with("--"))
            .collect();

        assert!(
            meaningful_lines.is_empty(),
            "相同的表不应该产生差异 SQL,但生成了: {:?}",
            meaningful_lines
        );

        println!("✅ SQL diff 测试通过:相同的表没有产生差异");
    }

    #[test]
    fn test_create_table_concatenation_with_semicolons() {
        // SHOW CREATE TABLE returns statements without a trailing ';'. They must be
        // terminated before being concatenated for the parser.
        let mut create_sqls = String::new();

        let stmt1 = "CREATE TABLE `agent_config` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt1);
        if !stmt1.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");

        let stmt2 = "CREATE TABLE `agent_component_config` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt2);
        if !stmt2.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");

        println!("拼接后的 SQL:\n{}", create_sqls);

        assert!(
            create_sqls.contains("ENGINE=InnoDB;"),
            "第一个表的语句应该以分号结尾"
        );
        assert!(
            create_sqls.matches("ENGINE=InnoDB;").count() == 2,
            "两个表的语句都应该以分号结尾"
        );

        // The concatenated SQL must be parseable.
        use crate::sql_diff::parse_sql_tables;
        let result = parse_sql_tables(&create_sqls);

        if let Err(ref e) = result {
            println!("解析错误: {}", e);
        }

        assert!(
            result.is_ok(),
            "拼接后的 SQL 应该可以被正确解析: {:?}",
            result.err()
        );

        let tables = result.unwrap();
        println!("解析出的表: {:?}", tables.keys().collect::<Vec<_>>());
        assert_eq!(
            tables.len(),
            2,
            "应该解析出 2 个表,实际解析出 {} 个",
            tables.len()
        );

        // Accept either key form in case the parser keeps backticks.
        let has_agent_config =
            tables.contains_key("agent_config") || tables.contains_key("`agent_config`");
        let has_agent_component_config = tables.contains_key("agent_component_config")
            || tables.contains_key("`agent_component_config`");

        assert!(has_agent_config, "应该包含 agent_config 表");
        assert!(
            has_agent_component_config,
            "应该包含 agent_component_config 表"
        );

        println!("✅ CREATE TABLE 语句拼接测试通过");
    }
}