Skip to main content

client_core/sql_diff/
parser.rs

1use super::types::{TableColumn, TableDefinition, TableIndex};
2use crate::error::DuckError;
3use regex::Regex;
4use sqlparser::ast::{ColumnDef, DataType, Statement, TableConstraint};
5use sqlparser::dialect::MySqlDialect;
6use sqlparser::parser::Parser;
7use std::collections::HashMap;
8use tracing::{debug, info, warn};
9
10/// 移除标识符中的反引号
11#[inline]
12fn strip_backticks(s: &str) -> String {
13    s.trim_matches('`').to_string()
14}
15
16/// 将 sqlparser 的标识符转换为字符串(移除反引号)
17#[inline]
18fn ident_to_string<T: ToString>(ident: &T) -> String {
19    strip_backticks(&ident.to_string())
20}
21
22/// 解析SQL文件中的表结构
23pub fn parse_sql_tables(sql_content: &str) -> Result<HashMap<String, TableDefinition>, DuckError> {
24    let mut tables = HashMap::new();
25
26    // 使用正则表达式找到 USE 语句的位置,然后从该位置开始解析后续的 CREATE TABLE 语句
27    let create_table_statements = extract_create_table_statements_with_regex(sql_content)?;
28
29    let dialect = MySqlDialect {};
30
31    for create_table_sql in create_table_statements {
32        debug!("Parsing CREATE TABLE statement: {}", create_table_sql);
33
34        match Parser::parse_sql(&dialect, &create_table_sql) {
35            Ok(statements) => {
36                for statement in statements {
37                    if let Statement::CreateTable(create_table) = statement {
38                        // 移除表名中的反引号,确保表名统一
39                        let table_name = ident_to_string(&create_table.name);
40                        debug!("Parsing table: {}", table_name);
41
42                        let mut table_columns = Vec::new();
43                        let mut table_indexes = Vec::new();
44                        let mut primary_key_columns = Vec::new();
45
46                        // 解析列定义
47                        for column in &create_table.columns {
48                            let column_def = parse_column_definition(column)?;
49
50                            // 检查是否是列级别的主键
51                            if is_column_primary_key(column) {
52                                primary_key_columns.push(ident_to_string(&column.name));
53                            }
54
55                            table_columns.push(column_def);
56                        }
57
58                        // 如果有列级别的主键,添加到索引列表
59                        if !primary_key_columns.is_empty() {
60                            table_indexes.push(TableIndex {
61                                name: "PRIMARY".to_string(),
62                                columns: primary_key_columns,
63                                is_primary: true,
64                                is_unique: true,
65                                index_type: Some("PRIMARY".to_string()),
66                            });
67                        }
68
69                        // 解析约束(包括索引)
70                        for constraint in &create_table.constraints {
71                            if let Some(index) = parse_table_constraint(constraint)? {
72                                table_indexes.push(index);
73                            }
74                        }
75
76                        let table_def = TableDefinition {
77                            name: table_name.clone(),
78                            columns: table_columns,
79                            indexes: table_indexes,
80                            engine: None,  // 可以从原始SQL字符串中提取
81                            charset: None, // 可以从原始SQL字符串中提取
82                        };
83
84                        tables.insert(table_name, table_def);
85                    }
86                }
87            }
88            Err(e) => {
89                warn!("Failed to parse SQL statement: {} - error: {}", create_table_sql, e);
90            }
91        }
92    }
93
94    // 🔧 新增:解析独立的 CREATE INDEX 语句
95    parse_standalone_indexes(sql_content, &mut tables)?;
96
97    info!("Successfully parsed {} tables", tables.len());
98    Ok(tables)
99}
100
101/// 使用正则表达式找到 USE 语句位置,然后提取后续的 CREATE TABLE 语句
102fn extract_create_table_statements_with_regex(sql_content: &str) -> Result<Vec<String>, DuckError> {
103    // 创建正则表达式来匹配 USE 语句
104    let use_regex = Regex::new(r"(?i)^\s*USE\s+[^;]+;\s*$")
105        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;
106
107    let lines: Vec<&str> = sql_content.lines().collect();
108    let mut start_parsing_from_line = 0;
109
110    // 查找 USE 语句
111    for (line_idx, line) in lines.iter().enumerate() {
112        if use_regex.is_match(line) {
113            debug!("Found USE statement at line {}: {}", line_idx + 1, line);
114            start_parsing_from_line = line_idx + 1; // 从下一行开始
115            break;
116        }
117    }
118
119    // 如果没有找到 USE 语句,从头开始解析
120    if start_parsing_from_line == 0 {
121        debug!("No USE statement found, parsing entire file from the beginning");
122    }
123
124    // 从指定位置开始提取内容
125    let content_to_parse = if start_parsing_from_line < lines.len() {
126        lines[start_parsing_from_line..].join("\n")
127    } else {
128        sql_content.to_string()
129    };
130
131    extract_create_table_statements_from_content(&content_to_parse)
132}
133
134/// 从指定内容中提取 CREATE TABLE 语句
135fn extract_create_table_statements_from_content(content: &str) -> Result<Vec<String>, DuckError> {
136    let mut statements = Vec::new();
137
138    // 创建正则表达式来匹配 CREATE TABLE 语句的开始
139    let create_table_regex = Regex::new(r"(?i)^\s*CREATE\s+TABLE")
140        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;
141
142    let lines: Vec<&str> = content.lines().collect();
143    let mut current_statement = String::new();
144    let mut in_create_table = false;
145    let mut paren_count = 0;
146    let mut in_string = false;
147    let mut escape_next = false;
148
149    for line in lines {
150        let trimmed = line.trim();
151
152        // 跳过空行和注释
153        if trimmed.is_empty() || trimmed.starts_with("--") || trimmed.starts_with("/*") {
154            continue;
155        }
156
157        // 检查是否是 CREATE TABLE 语句的开始
158        if !in_create_table && create_table_regex.is_match(line) {
159            in_create_table = true;
160            current_statement.clear();
161            paren_count = 0;
162            in_string = false;
163            escape_next = false;
164        }
165
166        if in_create_table {
167            current_statement.push_str(line);
168            current_statement.push('\n');
169
170            // 逐字符分析以正确处理括号平衡
171            for ch in line.chars() {
172                if escape_next {
173                    escape_next = false;
174                    continue;
175                }
176
177                match ch {
178                    '\\' if in_string => {
179                        escape_next = true;
180                    }
181                    '\'' | '"' | '`' => {
182                        in_string = !in_string;
183                    }
184                    '(' if !in_string => {
185                        paren_count += 1;
186                    }
187                    ')' if !in_string => {
188                        paren_count -= 1;
189                    }
190                    ';' if !in_string && paren_count <= 0 => {
191                        // 找到语句结束
192                        statements.push(current_statement.trim().to_string());
193                        current_statement.clear();
194                        in_create_table = false;
195                        paren_count = 0;
196                        break;
197                    }
198                    _ => {}
199                }
200            }
201        }
202    }
203
204    // 处理可能没有分号结尾的语句
205    if in_create_table && !current_statement.trim().is_empty() {
206        statements.push(current_statement.trim().to_string());
207    }
208
209    debug!("Extracted {} CREATE TABLE statements", statements.len());
210    Ok(statements)
211}
212
213/// 解析列定义
214fn parse_column_definition(column: &ColumnDef) -> Result<TableColumn, DuckError> {
215    let column_name = ident_to_string(&column.name);
216    let data_type = format_data_type(&column.data_type);
217
218    let mut nullable = true;
219    let mut default_value = None;
220    let mut comment = None;
221    let mut auto_increment = false;
222
223    // 检查列选项
224    for option in &column.options {
225        match &option.option {
226            sqlparser::ast::ColumnOption::NotNull => {
227                nullable = false;
228            }
229            sqlparser::ast::ColumnOption::Default(expr) => {
230                default_value = Some(format_default_value(expr));
231            }
232            sqlparser::ast::ColumnOption::Comment(c) => {
233                comment = Some(c.clone());
234            }
235            sqlparser::ast::ColumnOption::Unique { is_primary, .. } => {
236                if *is_primary {
237                    nullable = false; // 主键不能为空
238                }
239            }
240            sqlparser::ast::ColumnOption::DialectSpecific(tokens) => {
241                // 检查是否是AUTO_INCREMENT
242                let token_str = tokens
243                    .iter()
244                    .map(|t| t.to_string())
245                    .collect::<Vec<_>>()
246                    .join(" ")
247                    .to_uppercase();
248                if token_str.contains("AUTO_INCREMENT") {
249                    auto_increment = true;
250                }
251            }
252            _ => {}
253        }
254    }
255
256    Ok(TableColumn {
257        name: column_name,
258        data_type,
259        nullable,
260        default_value,
261        auto_increment,
262        comment,
263    })
264}
265
266/// 解析表约束
267fn parse_table_constraint(constraint: &TableConstraint) -> Result<Option<TableIndex>, DuckError> {
268    match constraint {
269        TableConstraint::PrimaryKey { columns, .. } => {
270            let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
271
272            Ok(Some(TableIndex {
273                name: "PRIMARY".to_string(),
274                columns: column_names,
275                is_primary: true,
276                is_unique: true,
277                index_type: Some("PRIMARY".to_string()),
278            }))
279        }
280        TableConstraint::Unique { columns, name, .. } => {
281            let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
282            let index_name = name
283                .as_ref()
284                .map(ident_to_string)
285                .unwrap_or_else(|| format!("unique_{}", column_names.join("_")));
286
287            Ok(Some(TableIndex {
288                name: index_name,
289                columns: column_names,
290                is_primary: false,
291                is_unique: true,
292                index_type: Some("UNIQUE".to_string()),
293            }))
294        }
295        TableConstraint::Index { name, columns, .. } => {
296            let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
297            let index_name = name
298                .as_ref()
299                .map(ident_to_string)
300                .unwrap_or_else(|| format!("idx_{}", column_names.join("_")));
301
302            Ok(Some(TableIndex {
303                name: index_name,
304                columns: column_names,
305                is_primary: false,
306                is_unique: false,
307                index_type: Some("INDEX".to_string()),
308            }))
309        }
310        _ => Ok(None),
311    }
312}
313
314/// 格式化默认值(特别处理函数类型的默认值)
315fn format_default_value(expr: &sqlparser::ast::Expr) -> String {
316    debug!("format_default_value called, expression: {:?}", expr);
317
318    match expr {
319        // 处理函数调用,如 CURRENT_TIMESTAMP
320        sqlparser::ast::Expr::Function(function) => {
321            let function_name = function.name.to_string();
322            debug!("Detected function call: {}", function_name);
323            // 对于 MySQL 的日期时间函数,不需要加引号,直接返回函数名
324            match function_name.to_uppercase().as_str() {
325                "CURRENT_TIMESTAMP" | "NOW" | "CURRENT_DATE" | "CURRENT_TIME"
326                | "LOCALTIMESTAMP" | "LOCALTIME" => {
327                    debug!("Recognized as MySQL datetime function, returning: {}", function_name);
328                    function_name
329                }
330                _ => {
331                    debug!("Other function, using default format: {}", function_name);
332                    // 其他函数保持原有格式
333                    format!("{expr}")
334                }
335            }
336        }
337
338        // 处理各种值类型
339        sqlparser::ast::Expr::Value(value_with_span) => {
340            debug!("Detected value type: {:?}", value_with_span);
341            match &value_with_span.value {
342                sqlparser::ast::Value::SingleQuotedString(s) => {
343                    debug!("String value: {} -> '{}'", s, s);
344                    format!("'{}'", s)
345                }
346                sqlparser::ast::Value::Number(_, _) => {
347                    debug!("Numeric value");
348                    // 数字类型不需要引号,直接返回表达式格式化结果
349                    format!("{expr}")
350                }
351                sqlparser::ast::Value::Null => {
352                    debug!("NULL value");
353                    "NULL".to_string()
354                }
355                sqlparser::ast::Value::Boolean(b) => {
356                    debug!("Boolean value: {}", b);
357                    b.to_string()
358                }
359                // 处理其他值类型
360                _ => {
361                    debug!("Other value type");
362                    format!("{expr}")
363                }
364            }
365        }
366
367        // 其他情况使用默认格式化
368        _ => {
369            debug!("Other expression type");
370            format!("{expr}")
371        }
372    }
373}
374
375/// 格式化数据类型
376fn format_data_type(data_type: &DataType) -> String {
377    match data_type {
378        DataType::Char(size) => {
379            if let Some(size) = size {
380                format!("CHAR({size})")
381            } else {
382                "CHAR".to_string()
383            }
384        }
385        DataType::Varchar(size) => {
386            if let Some(size) = size {
387                format!("VARCHAR({size})")
388            } else {
389                "VARCHAR".to_string()
390            }
391        }
392        DataType::Text => "TEXT".to_string(),
393        DataType::Int(_) => "INT".to_string(),
394        DataType::BigInt(_) => "BIGINT".to_string(),
395        DataType::TinyInt(_) => "TINYINT".to_string(),
396        DataType::SmallInt(_) => "SMALLINT".to_string(),
397        DataType::MediumInt(_) => "MEDIUMINT".to_string(),
398        DataType::Float(_) => "FLOAT".to_string(),
399        DataType::Double(_) => "DOUBLE".to_string(),
400        DataType::Decimal(exact_number_info) => match exact_number_info {
401            sqlparser::ast::ExactNumberInfo::PrecisionAndScale(precision, scale) => {
402                format!("DECIMAL({precision},{scale})")
403            }
404            sqlparser::ast::ExactNumberInfo::Precision(precision) => {
405                format!("DECIMAL({precision})")
406            }
407            sqlparser::ast::ExactNumberInfo::None => "DECIMAL".to_string(),
408        },
409        DataType::Boolean => "BOOLEAN".to_string(),
410        DataType::Date => "DATE".to_string(),
411        DataType::Time(_, _) => "TIME".to_string(),
412        DataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
413        DataType::Datetime(_) => "DATETIME".to_string(),
414        DataType::JSON => "JSON".to_string(),
415        DataType::Enum(variants, _max_length) => {
416            // 正确处理 ENUM 变体
417            let enum_values: Vec<String> = variants
418                .iter()
419                .filter_map(|variant| match variant {
420                    sqlparser::ast::EnumMember::Name(name) => Some(format!("'{}'", name)),
421                    sqlparser::ast::EnumMember::NamedValue(name, _expr) => {
422                        Some(format!("'{}'", name))
423                    }
424                })
425                .collect();
426
427            if enum_values.is_empty() {
428                "ENUM()".to_string()
429            } else {
430                format!("ENUM({})", enum_values.join(","))
431            }
432        }
433        _ => format!("{data_type:?}"), // 对于其他类型,使用 Debug 格式
434    }
435}
436
437/// 检查列是否是列级别的主键
438fn is_column_primary_key(column: &ColumnDef) -> bool {
439    for option in &column.options {
440        if let sqlparser::ast::ColumnOption::Unique { is_primary, .. } = &option.option {
441            if *is_primary {
442                return true;
443            }
444        }
445    }
446    false
447}
448
449/// 从 IndexColumn 列表中提取列名
450///
451/// 处理三种情况:
452/// 1. 简单列名:`column_name`
453/// 2. 复合标识符:`table.column` (只取最后一部分)
454/// 3. 复杂表达式:函数索引等 (使用 Display)
455fn extract_index_columns(index_columns: &[sqlparser::ast::IndexColumn]) -> Vec<String> {
456    index_columns
457        .iter()
458        .filter_map(|index_col| {
459            match &index_col.column.expr {
460                sqlparser::ast::Expr::Identifier(ident) => Some(strip_backticks(&ident.value)),
461                sqlparser::ast::Expr::CompoundIdentifier(idents) => {
462                    // 处理 table.column 格式,只取最后一个部分
463                    idents.last().map(|id| strip_backticks(&id.value))
464                }
465                _ => {
466                    // 对于函数索引等复杂表达式,使用 Display
467                    Some(strip_backticks(&index_col.column.to_string()))
468                }
469            }
470        })
471        .collect()
472}
473
474/// 解析独立的 CREATE INDEX 语句并添加到表定义中
475///
476/// 使用 sqlparser 库正确解析 SQL 语法
477///
478/// 格式示例:
479/// ```sql
480/// create index idx_space_id
481///     on agent_config (space_id);
482///
483/// create unique index uk_name
484///     on users (username);
485/// ```
486fn parse_standalone_indexes(
487    sql_content: &str,
488    tables: &mut HashMap<String, TableDefinition>,
489) -> Result<(), DuckError> {
490    let dialect = MySqlDialect {};
491    let mut index_count = 0;
492
493    // 提取所有 CREATE INDEX 语句
494    let index_statements = extract_create_index_statements(sql_content)?;
495
496    for index_sql in index_statements {
497        debug!("Parsing CREATE INDEX statement: {}", index_sql);
498
499        match Parser::parse_sql(&dialect, &index_sql) {
500            Ok(statements) => {
501                for statement in statements {
502                    if let Statement::CreateIndex(create_index) = statement {
503                        // 提取索引名称
504                        let index_name = create_index
505                            .name
506                            .as_ref()
507                            .map(ident_to_string)
508                            .unwrap_or_else(|| "unnamed_index".to_string());
509
510                        // 提取表名
511                        let table_name = ident_to_string(&create_index.table_name);
512
513                        // 提取列名列表
514                        let columns = extract_index_columns(&create_index.columns);
515
516                        if columns.is_empty() {
517                            warn!("Index {} has no column definition, skipping", index_name);
518                            continue;
519                        }
520
521                        // 检查是否是 UNIQUE 索引
522                        let is_unique = create_index.unique;
523
524                        // 查找对应的表
525                        if let Some(table_def) = tables.get_mut(&table_name) {
526                            // 检查是否已经存在同名索引
527                            if table_def.indexes.iter().any(|idx| idx.name == index_name) {
528                                debug!("Index {} already exists in table {}, skipping", index_name, table_name);
529                                continue;
530                            }
531
532                            // 添加索引到表定义
533                            table_def.indexes.push(TableIndex {
534                                name: index_name.clone(),
535                                columns: columns.clone(),
536                                is_primary: false,
537                                is_unique,
538                                index_type: if is_unique {
539                                    Some("UNIQUE".to_string())
540                                } else {
541                                    Some("INDEX".to_string())
542                                },
543                            });
544
545                            index_count += 1;
546                            debug!(
547                                "添加独立索引: {} 到表 {} (列: {:?}, unique: {})",
548                                index_name, table_name, columns, is_unique
549                            );
550                        } else {
551                            warn!("Index {} references table {} which does not exist, skipping", index_name, table_name);
552                        }
553                    }
554                }
555            }
556            Err(e) => {
557                warn!("Failed to parse CREATE INDEX statement: {} - error: {}", index_sql, e);
558            }
559        }
560    }
561
562    if index_count > 0 {
563        info!("Successfully parsed {} standalone CREATE INDEX statements", index_count);
564    }
565
566    Ok(())
567}
568
569/// 提取所有 CREATE INDEX 语句
570///
571/// 使用简单的状态机来识别完整的 CREATE INDEX 语句
572fn extract_create_index_statements(sql_content: &str) -> Result<Vec<String>, DuckError> {
573    let mut statements = Vec::new();
574    let mut current_statement = String::new();
575    let mut in_create_index = false;
576
577    // 正则表达式只用于识别语句开始,不用于解析
578    let create_index_regex = Regex::new(r"(?i)^\s*CREATE\s+(UNIQUE\s+)?INDEX")
579        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {}", e)))?;
580
581    for line in sql_content.lines() {
582        let trimmed = line.trim();
583
584        // 跳过空行和注释
585        if trimmed.is_empty() || trimmed.starts_with("--") {
586            continue;
587        }
588
589        // 检查是否是 CREATE INDEX 语句的开始
590        if !in_create_index && create_index_regex.is_match(line) {
591            in_create_index = true;
592            current_statement.clear();
593        }
594
595        if in_create_index {
596            current_statement.push_str(line);
597            current_statement.push(' ');
598
599            // 检查是否遇到分号(语句结束)
600            if trimmed.ends_with(';') {
601                statements.push(current_statement.trim().to_string());
602                current_statement.clear();
603                in_create_index = false;
604            }
605        }
606    }
607
608    // 处理可能没有分号结尾的语句
609    if in_create_index && !current_statement.trim().is_empty() {
610        statements.push(current_statement.trim().to_string());
611    }
612
613    debug!("Extracted {} CREATE INDEX statements", statements.len());
614    Ok(statements)
615}