Skip to main content

client_core/
mysql_executor.rs

1use crate::container::DockerManager;
2use crate::sql_diff::TableDefinition;
3use anyhow::{Context, Result, anyhow};
4use docker_compose_types as dct;
5use mysql_async::prelude::*;
6use mysql_async::{Opts, Pool, Row, Transaction, TxOpts, from_row};
7
8/// MySQL容器异步差异SQL执行器
9/// 专为Duck Client自动升级部署设计
10pub struct MySqlExecutor {
11    pool: Pool,
12    config: MySqlConfig,
13}
14
/// MySQL connection settings adapted to the existing deployment.
///
/// `Debug` output redacts the password so configs can be logged safely.
#[derive(Clone)]
pub struct MySqlConfig {
    /// Host name or IP address of the server.
    pub host: String,
    /// TCP port of the server.
    pub port: u16,
    /// User name to authenticate as.
    pub user: String,
    /// Password for `user` (never shown by `Debug`).
    pub password: String,
    /// Database (schema) name to connect to.
    pub database: String,
}

impl std::fmt::Debug for MySqlConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Configs end up in logs and error reports; never print the secret itself.
        f.debug_struct("MySqlConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .finish()
    }
}
24
25impl MySqlConfig {
26    /// 通过解析 docker-compose.yml 文件为容器环境适配配置
27    pub async fn for_container(compose_file: Option<&str>, env_file: Option<&str>) -> Result<Self> {
28        let docker_manager = match (compose_file, env_file) {
29            (Some(c), Some(e)) => DockerManager::with_project(c, e, None)?,
30            _ => {
31                return Err(anyhow!(
32                    "docker-compose.yml and .env paths are required to load Docker Compose configuration"
33                ));
34            }
35        };
36        let compose_config = docker_manager
37            .load_compose_config()
38            .context("Failed to load Docker Compose configuration")?;
39
40        let mysql_service = compose_config
41            .services
42            .0
43            .get("mysql")
44            .and_then(|s| s.as_ref())
45            .ok_or_else(|| anyhow!("'mysql' service not found in docker-compose.yml"))?;
46
47        let mut config_map = std::collections::HashMap::new();
48        if let dct::Environment::List(env_list) = &mysql_service.environment {
49            for item in env_list {
50                if let Some((key, value)) = item.split_once('=') {
51                    config_map.insert(key.to_string(), value.to_string());
52                }
53            }
54        }
55
56        let port = match &mysql_service.ports {
57            dct::Ports::Short(ports_list) => ports_list
58                .iter()
59                .find_map(|p| {
60                    let parts: Vec<&str> = p.split(':').collect();
61                    if parts.len() == 2 && parts[1] == "3306" {
62                        parts[0].parse::<u16>().ok()
63                    } else {
64                        None
65                    }
66                })
67                .ok_or_else(|| {
68                    anyhow!("No mapping to container port 3306 found in 'mysql' service")
69                })?,
70            dct::Ports::Long(ports_list) => ports_list
71                .iter()
72                .find_map(|p| {
73                    if p.target == 3306 {
74                        match &p.published {
75                            Some(dct::PublishedPort::Single(port_num)) => Some(*port_num),
76                            Some(dct::PublishedPort::Range(port_str)) => {
77                                port_str.parse::<u16>().ok()
78                            }
79                            None => None,
80                        }
81                    } else {
82                        None
83                    }
84                })
85                .ok_or_else(|| {
86                    anyhow!("No mapping to container port 3306 found in 'mysql' service")
87                })?,
88        };
89
90        Ok(MySqlConfig {
91            host: "127.0.0.1".to_string(),
92            port,
93            user: config_map
94                .get("MYSQL_USER")
95                .cloned()
96                .unwrap_or_else(|| "root".to_string()),
97            password: config_map
98                .get("MYSQL_PASSWORD")
99                .cloned()
100                .unwrap_or_else(|| "root".to_string()),
101            database: config_map
102                .get("MYSQL_DATABASE")
103                .cloned()
104                .unwrap_or_else(|| "agent_platform".to_string()),
105        })
106    }
107
108    /// 生成连接URL
109    fn to_url(&self) -> String {
110        format!(
111            "mysql://{}:{}@{}:{}/{}",
112            self.user, self.password, self.host, self.port, self.database
113        )
114    }
115}
116
117impl MySqlExecutor {
118    /// 创建新的执行器
119    pub fn new(config: MySqlConfig) -> Self {
120        let opts = Opts::from_url(&config.to_url()).unwrap();
121        let pool = Pool::new(opts);
122        Self { pool, config }
123    }
124
125    /// 测试连接是否可用
126    pub async fn test_connection(&self) -> Result<(), mysql_async::Error> {
127        let mut conn = self.pool.get_conn().await?;
128        conn.query_drop("SELECT 1").await?;
129        Ok(())
130    }
131
132    /// 执行单个SQL语句
133    pub async fn execute_single(&self, sql: &str) -> Result<u64, mysql_async::Error> {
134        let mut conn = self.pool.get_conn().await?;
135        let result = conn.query_iter(sql).await?;
136        Ok(result.affected_rows())
137    }
138
139    /// 执行差异SQL内容(多语句支持)
140    /// 自动处理注释和空行,支持事务回滚
141    pub async fn execute_diff_sql(&self, sql_content: &str) -> Result<Vec<String>, anyhow::Error> {
142        self.execute_diff_sql_with_retry(sql_content, 1).await
143    }
144
145    /// 带重试机制的SQL执行
146    pub async fn execute_diff_sql_with_retry(
147        &self,
148        sql_content: &str,
149        max_retries: u8,
150    ) -> Result<Vec<String>, anyhow::Error> {
151        let sql_lines = self.parse_sql_commands(sql_content);
152        let mut results = Vec::new();
153        let mut last_error: Option<mysql_async::Error> = None;
154
155        for attempt in 0..=max_retries {
156            if attempt > 0 {
157                tokio::time::sleep(std::time::Duration::from_millis(500 * attempt as u64)).await;
158                results.push(format!("🔄 Retrying attempt {attempt}/{max_retries}..."));
159            }
160
161            let mut conn = self.pool.get_conn().await?;
162            let mut tx = conn.start_transaction(TxOpts::default()).await?;
163
164            // 记录本次尝试前的日志数量,如果失败可以回滚
165            let results_len_before_attempt = results.len();
166
167            match self
168                .execute_in_transaction(&mut tx, &sql_lines, &mut results)
169                .await
170            {
171                Ok(_) => {
172                    tx.commit().await?;
173                    results.insert(0, "✅ Diff SQL executed successfully".to_string());
174                    return Ok(results);
175                }
176                Err(e) => {
177                    tx.rollback().await?;
178                    // 移除本次失败尝试中添加的日志
179                    results.truncate(results_len_before_attempt);
180                    results.push(format!("❌ Attempt {} failed: {}", attempt + 1, e));
181                    last_error = Some(e);
182                }
183            }
184        }
185
186        Err(anyhow::anyhow!(
187            "❌ SQL execution failed after {} attempts. Last error: {}",
188            max_retries + 1,
189            last_error.unwrap()
190        ))
191    }
192
193    /// 执行在事务中的差异SQL
194    async fn execute_in_transaction<'a>(
195        &self,
196        tx: &mut Transaction<'a>,
197        lines: &[String],
198        results: &mut Vec<String>,
199    ) -> Result<(), mysql_async::Error> {
200        for (idx, sql) in lines.iter().enumerate() {
201            if sql.starts_with("--") || sql.trim().is_empty() {
202                continue;
203            }
204
205            tx.query_drop(sql).await?;
206            results.push(format!("[{}] ✅ {}", idx + 1, sql));
207        }
208        Ok(())
209    }
210
211    /// 解析SQL内容为可执行的命令列表
212    fn parse_sql_commands(&self, sql_content: &str) -> Vec<String> {
213        let mut commands = Vec::new();
214        let mut current_command = String::new();
215
216        for line in sql_content.lines() {
217            let line = line.trim();
218
219            if line.starts_with("--") || line.is_empty() {
220                continue;
221            }
222
223            current_command.push_str(line);
224            current_command.push(' ');
225
226            // 如果行的末尾是分号SQL结束
227            if line.ends_with(';') || line.ends_with("ENGINE=InnoDB;") || line.ends_with(");") {
228                commands.push(current_command.trim().to_string());
229                current_command.clear();
230            }
231        }
232
233        if !current_command.trim().is_empty() {
234            commands.push(current_command.trim().to_string());
235        }
236
237        commands
238    }
239
240    /// 获取数据库表结构信息
241    pub async fn get_table_info(&self, table_name: &str) -> Result<(), mysql_async::Error> {
242        let mut conn = self.pool.get_conn().await?;
243        let results: Vec<Row> = conn.query(format!("DESCRIBE {table_name}")).await?;
244
245        for row in results {
246            println!("{row:?}");
247        }
248        Ok(())
249    }
250
251    /// 抓取在线数据库架构:通过 SHOW CREATE TABLE 获取真实DDL,再用 sqlparser 解析为内部类型
252    pub async fn fetch_live_schema(
253        &self,
254    ) -> Result<std::collections::HashMap<String, TableDefinition>, anyhow::Error> {
255        let (tables, _sql) = self.fetch_live_schema_with_sql().await?;
256        Ok(tables)
257    }
258
259    /// 抓取在线数据库架构并返回原始 SQL
260    /// 返回:(解析后的表定义, 原始 CREATE TABLE SQL)
261    pub async fn fetch_live_schema_with_sql(
262        &self,
263    ) -> Result<(std::collections::HashMap<String, TableDefinition>, String), anyhow::Error> {
264        use crate::sql_diff::parse_sql_tables;
265
266        let mut conn = self.pool.get_conn().await?;
267
268        // 获取当前数据库所有表名
269        let table_names: Vec<String> = conn
270            .exec(
271                r#"SELECT TABLE_NAME
272                    FROM INFORMATION_SCHEMA.TABLES
273                    WHERE TABLE_SCHEMA = ?
274                    ORDER BY TABLE_NAME"#,
275                (self.config.database.clone(),),
276            )
277            .await?
278            .into_iter()
279            .map(|row| {
280                let (name,): (String,) = from_row(row);
281                name
282            })
283            .collect();
284
285        // 拼接所有表的 CREATE 语句
286        let mut create_sqls = String::new();
287        for table in &table_names {
288            let query = format!("SHOW CREATE TABLE `{}`", table);
289            let row: Row = conn.exec_first(query, ()).await?.ok_or_else(|| {
290                anyhow::anyhow!(format!(
291                    "Failed to get CREATE statement for table: {}",
292                    table
293                ))
294            })?;
295            // MySQL返回两列:Table, Create Table
296            let (_tbl_name, create_stmt): (String, String) = from_row(row);
297            create_sqls.push_str(&create_stmt);
298
299            // 确保每个 CREATE TABLE 语句以分号结尾
300            if !create_stmt.trim().ends_with(';') {
301                create_sqls.push(';');
302            }
303            create_sqls.push_str("\n\n");
304        }
305
306        // 使用 sqlparser 解析 DDL,严格避免正则
307        let tables = parse_sql_tables(&create_sqls)
308            .map_err(|e| anyhow::anyhow!(format!("Failed to parse online DDL: {}", e)))?;
309
310        Ok((tables, create_sqls))
311    }
312
313    /// 验证执行结果
314    pub async fn verify_execution(
315        &self,
316        _expected_changes: &str,
317    ) -> Result<bool, mysql_async::Error> {
318        let mut conn = self.pool.get_conn().await?;
319
320        // 简单的执行确认
321        let result: Option<(i32,)> = conn.query_first("SELECT 1 as verification_status").await?;
322        if let Some((1,)) = result {
323            Ok(true)
324        } else {
325            Ok(false)
326        }
327    }
328
329    /// 检查数据库连接健康
330    pub async fn health_check(&self) -> HealthStatus {
331        match self.test_connection().await {
332            Ok(_) => HealthStatus::Healthy,
333            Err(e) => HealthStatus::Failed(e.to_string()),
334        }
335    }
336}
337
/// Result of a database health check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthStatus {
    /// The check succeeded.
    Healthy,
    /// The check failed; carries the rendered error message.
    Failed(String),
}
344
/// Record of one executed statement.
///
/// NOTE(review): not constructed in this file; the field meanings below follow their names.
/// Confirm them against the producers.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecutionResult {
    /// SQL text of the statement.
    pub sql: String,
    /// Presumably `true` when the statement succeeded.
    pub status: bool,
    /// Affected-row count, if known.
    pub rows_affected: Option<u64>,
    /// Error message, if the statement failed.
    pub error: Option<String>,
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the MySQL configuration from the crate's `fixtures/` compose project.
    async fn fixture_config() -> MySqlConfig {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let fixtures = std::path::Path::new(&manifest_dir).join("fixtures");
        let compose_path = fixtures.join("docker-compose.yml");
        let env_path = fixtures.join(".env");
        MySqlConfig::for_container(
            Some(compose_path.to_str().unwrap()),
            Some(env_path.to_str().unwrap()),
        )
        .await
        .unwrap()
    }

    /// Executor for tests that only exercise pure helpers and never issue a query.
    fn offline_executor() -> MySqlExecutor {
        MySqlExecutor::new(MySqlConfig {
            host: "127.0.0.1".to_string(),
            port: 3306,
            user: "root".to_string(),
            password: "root".to_string(),
            database: "agent_platform".to_string(),
        })
    }

    #[tokio::test]
    async fn test_mysql_connection() {
        let config = fixture_config().await;
        let executor = MySqlExecutor::new(config.clone());
        if executor.test_connection().await.is_err() {
            println!("⚠️ MySQL容器未运行,跳过测试");
            return;
        }

        // Work in a scratch database:
        // - each executor call takes its own pooled connection, so `USE test_db` would not
        //   reliably carry over;
        // - the diff would otherwise land in the configured (real) database;
        // - starting from a clean slate keeps reruns from failing on duplicate columns/indexes.
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();
        executor
            .execute_single("CREATE DATABASE test_db")
            .await
            .unwrap();

        let scratch = MySqlExecutor::new(MySqlConfig {
            database: "test_db".to_string(),
            ..config
        });

        scratch
            .execute_single(
                "CREATE TABLE IF NOT EXISTS test_table (id INT PRIMARY KEY, name VARCHAR(255))",
            )
            .await
            .unwrap();

        let outcome = scratch
            .execute_diff_sql("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY AUTO_INCREMENT, name VARCHAR(100)); \
                             ALTER TABLE users ADD COLUMN email VARCHAR(255); \
                             CREATE INDEX idx_name ON users(name);")
            .await;

        // Clean up before asserting so a failure does not leave test_db behind.
        executor
            .execute_single("DROP DATABASE IF EXISTS test_db")
            .await
            .unwrap();

        let results = outcome.unwrap();
        assert!(!results.is_empty());
        println!("✅ MySQL执行器测试通过");
    }

    #[tokio::test]
    async fn test_parse_sql_commands() {
        let content = "-- 注释\n\
                      CREATE TABLE users (id INT);\n\
                      ALTER TABLE users ADD COLUMN name VARCHAR(100);\n\
                      CREATE INDEX idx_name ON users(name);";

        let commands = offline_executor().parse_sql_commands(content);
        assert_eq!(commands.len(), 3);
        assert!(commands[0].contains("CREATE TABLE users"));
        assert!(commands[1].contains("ALTER TABLE users ADD COLUMN name"));
    }

    #[tokio::test]
    async fn test_empty_and_comments() {
        let content = "-- This is a comment\n\nCREATE TABLE test (id INT);\n-- Another comment";

        let commands = offline_executor().parse_sql_commands(content);
        assert_eq!(commands.len(), 1);
        assert_eq!(commands[0], "CREATE TABLE test (id INT);");
    }

    #[test]
    fn test_table_name_normalization() {
        // Backticked and bare table names must be recognised as the same table.
        use crate::sql_diff::parse_sql_tables;

        // SQL 1: backticked table name.
        let sql_with_backticks = "CREATE TABLE `test_table` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";

        // SQL 2: bare table name.
        let sql_without_backticks = "CREATE TABLE test_table (\n  id int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (id)\n) ENGINE=InnoDB;";

        let tables1 = parse_sql_tables(sql_with_backticks).expect("解析带反引号的 SQL 失败");
        let tables2 = parse_sql_tables(sql_without_backticks).expect("解析不带反引号的 SQL 失败");

        // Both should yield the same (unquoted) key.
        assert!(
            tables1.contains_key("test_table"),
            "带反引号的表名应该被标准化为 test_table"
        );
        assert!(
            tables2.contains_key("test_table"),
            "不带反引号的表名应该是 test_table"
        );

        // No key may keep its backticks.
        assert!(
            !tables1.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );
        assert!(
            !tables2.contains_key("`test_table`"),
            "不应该有带反引号的表名作为 key"
        );

        println!("✅ 表名标准化测试通过");
    }

    #[test]
    fn test_sql_diff_with_same_tables() {
        // A table read from MySQL (backticked) vs. the same table from a file (bare).
        use crate::sql_diff::{generate_schema_diff, parse_sql_tables};

        // As returned by SHOW CREATE TABLE (backticked).
        let mysql_sql = "CREATE TABLE `custom_page_config` (\n  `id` bigint NOT NULL AUTO_INCREMENT,\n  `name` varchar(255) NOT NULL,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB;";

        // As read from a schema file (bare).
        let file_sql = "CREATE TABLE custom_page_config (\n  id bigint NOT NULL AUTO_INCREMENT,\n  name varchar(255) NOT NULL,\n  PRIMARY KEY (id)\n) ENGINE=InnoDB;";

        let mysql_tables = parse_sql_tables(mysql_sql).expect("解析 MySQL SQL 失败");
        let file_tables = parse_sql_tables(file_sql).expect("解析文件 SQL 失败");

        println!("MySQL 表: {:?}", mysql_tables.keys().collect::<Vec<_>>());
        println!("文件表: {:?}", file_tables.keys().collect::<Vec<_>>());

        // Generate the diff from the two SQL strings.
        let (diff_sql, description) =
            generate_schema_diff(Some(mysql_sql), file_sql, Some("在线架构"), "目标版本")
                .expect("生成差异 SQL 失败");

        println!("差异描述: {}", description);
        println!("差异 SQL:\n{}", diff_sql);

        // Identical structures must not produce any executable diff lines.
        let meaningful_lines: Vec<&str> = diff_sql
            .lines()
            .filter(|line| !line.trim().is_empty() && !line.trim().starts_with("--"))
            .collect();

        assert!(
            meaningful_lines.is_empty(),
            "相同的表不应该产生差异 SQL,但生成了: {:?}",
            meaningful_lines
        );

        println!("✅ SQL diff 测试通过:相同的表没有产生差异");
    }

    #[test]
    fn test_create_table_concatenation_with_semicolons() {
        // Simulates several SHOW CREATE TABLE results, which come without a trailing `;`.
        let mut create_sqls = String::new();

        // First table statement (no semicolon).
        let stmt1 = "CREATE TABLE `agent_config` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt1);

        // Append the terminator, as fetch_live_schema_with_sql does.
        if !stmt1.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");

        // Second table statement (no semicolon).
        let stmt2 = "CREATE TABLE `agent_component_config` (\n  `id` int NOT NULL AUTO_INCREMENT,\n  PRIMARY KEY (`id`)\n) ENGINE=InnoDB";
        create_sqls.push_str(stmt2);

        if !stmt2.trim().ends_with(';') {
            create_sqls.push(';');
        }
        create_sqls.push_str("\n\n");

        println!("拼接后的 SQL:\n{}", create_sqls);

        // Every CREATE TABLE statement must now end with a semicolon.
        assert!(
            create_sqls.contains("ENGINE=InnoDB;"),
            "第一个表的语句应该以分号结尾"
        );
        assert!(
            create_sqls.matches("ENGINE=InnoDB;").count() == 2,
            "两个表的语句都应该以分号结尾"
        );

        // The concatenation must be parseable by sqlparser.
        use crate::sql_diff::parse_sql_tables;
        let result = parse_sql_tables(&create_sqls);

        if let Err(ref e) = result {
            println!("解析错误: {}", e);
        }

        assert!(
            result.is_ok(),
            "拼接后的 SQL 应该可以被正确解析: {:?}",
            result.err()
        );

        let tables = result.unwrap();
        println!("解析出的表: {:?}", tables.keys().collect::<Vec<_>>());
        assert_eq!(
            tables.len(),
            2,
            "应该解析出 2 个表,实际解析出 {} 个",
            tables.len()
        );

        // Table names may or may not keep their backticks; accept either.
        let has_agent_config =
            tables.contains_key("agent_config") || tables.contains_key("`agent_config`");
        let has_agent_component_config = tables.contains_key("agent_component_config")
            || tables.contains_key("`agent_component_config`");

        assert!(has_agent_config, "应该包含 agent_config 表");
        assert!(
            has_agent_component_config,
            "应该包含 agent_component_config 表"
        );

        println!("✅ CREATE TABLE 语句拼接测试通过");
    }
}