Skip to main content

client_core/sql_diff/
parser.rs

1use super::types::{TableColumn, TableDefinition, TableIndex};
2use crate::error::DuckError;
3use regex::Regex;
4use sqlparser::ast::{ColumnDef, DataType, Statement, TableConstraint};
5use sqlparser::dialect::MySqlDialect;
6use sqlparser::parser::Parser;
7use std::collections::HashMap;
8use tracing::{debug, info, warn};
9
10/// 移除标识符中的反引号
11#[inline]
12fn strip_backticks(s: &str) -> String {
13    s.trim_matches('`').to_string()
14}
15
16/// 将 sqlparser 的标识符转换为字符串(移除反引号)
17#[inline]
18fn ident_to_string<T: ToString>(ident: &T) -> String {
19    strip_backticks(&ident.to_string())
20}
21
22/// 解析SQL文件中的表结构
23pub fn parse_sql_tables(sql_content: &str) -> Result<HashMap<String, TableDefinition>, DuckError> {
24    let mut tables = HashMap::new();
25
26    // 使用正则表达式找到 USE 语句的位置,然后从该位置开始解析后续的 CREATE TABLE 语句
27    let create_table_statements = extract_create_table_statements_with_regex(sql_content)?;
28
29    let dialect = MySqlDialect {};
30
31    for create_table_sql in create_table_statements {
32        debug!("Parsing CREATE TABLE statement: {}", create_table_sql);
33
34        match Parser::parse_sql(&dialect, &create_table_sql) {
35            Ok(statements) => {
36                for statement in statements {
37                    if let Statement::CreateTable(create_table) = statement {
38                        // 移除表名中的反引号,确保表名统一
39                        let table_name = ident_to_string(&create_table.name);
40                        debug!("Parsing table: {}", table_name);
41
42                        let mut table_columns = Vec::new();
43                        let mut table_indexes = Vec::new();
44                        let mut primary_key_columns = Vec::new();
45
46                        // 解析列定义
47                        for column in &create_table.columns {
48                            let column_def = parse_column_definition(column)?;
49
50                            // 检查是否是列级别的主键
51                            if is_column_primary_key(column) {
52                                primary_key_columns.push(ident_to_string(&column.name));
53                            }
54
55                            table_columns.push(column_def);
56                        }
57
58                        // 如果有列级别的主键,添加到索引列表
59                        if !primary_key_columns.is_empty() {
60                            table_indexes.push(TableIndex {
61                                name: "PRIMARY".to_string(),
62                                columns: primary_key_columns,
63                                is_primary: true,
64                                is_unique: true,
65                                index_type: Some("PRIMARY".to_string()),
66                            });
67                        }
68
69                        // 解析约束(包括索引)
70                        for constraint in &create_table.constraints {
71                            if let Some(index) = parse_table_constraint(constraint)? {
72                                table_indexes.push(index);
73                            }
74                        }
75
76                        let table_def = TableDefinition {
77                            name: table_name.clone(),
78                            columns: table_columns,
79                            indexes: table_indexes,
80                            engine: None,  // 可以从原始SQL字符串中提取
81                            charset: None, // 可以从原始SQL字符串中提取
82                        };
83
84                        tables.insert(table_name, table_def);
85                    }
86                }
87            }
88            Err(e) => {
89                warn!(
90                    "Failed to parse SQL statement: {} - error: {}",
91                    create_table_sql, e
92                );
93            }
94        }
95    }
96
97    // 🔧 新增:解析独立的 CREATE INDEX 语句
98    parse_standalone_indexes(sql_content, &mut tables)?;
99
100    info!("Successfully parsed {} tables", tables.len());
101    Ok(tables)
102}
103
104/// 使用正则表达式找到 USE 语句位置,然后提取后续的 CREATE TABLE 语句
105fn extract_create_table_statements_with_regex(sql_content: &str) -> Result<Vec<String>, DuckError> {
106    // 创建正则表达式来匹配 USE 语句
107    let use_regex = Regex::new(r"(?i)^\s*USE\s+[^;]+;\s*$")
108        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;
109
110    let lines: Vec<&str> = sql_content.lines().collect();
111    let mut start_parsing_from_line = 0;
112
113    // 查找 USE 语句
114    for (line_idx, line) in lines.iter().enumerate() {
115        if use_regex.is_match(line) {
116            debug!("Found USE statement at line {}: {}", line_idx + 1, line);
117            start_parsing_from_line = line_idx + 1; // 从下一行开始
118            break;
119        }
120    }
121
122    // 如果没有找到 USE 语句,从头开始解析
123    if start_parsing_from_line == 0 {
124        debug!("No USE statement found, parsing entire file from the beginning");
125    }
126
127    // 从指定位置开始提取内容
128    let content_to_parse = if start_parsing_from_line < lines.len() {
129        lines[start_parsing_from_line..].join("\n")
130    } else {
131        sql_content.to_string()
132    };
133
134    extract_create_table_statements_from_content(&content_to_parse)
135}
136
137/// 从指定内容中提取 CREATE TABLE 语句
138fn extract_create_table_statements_from_content(content: &str) -> Result<Vec<String>, DuckError> {
139    let mut statements = Vec::new();
140
141    // 创建正则表达式来匹配 CREATE TABLE 语句的开始
142    let create_table_regex = Regex::new(r"(?i)^\s*CREATE\s+TABLE")
143        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;
144
145    let lines: Vec<&str> = content.lines().collect();
146    let mut current_statement = String::new();
147    let mut in_create_table = false;
148    let mut paren_count = 0;
149    let mut in_string = false;
150    let mut escape_next = false;
151
152    for line in lines {
153        let trimmed = line.trim();
154
155        // 跳过空行和注释
156        if trimmed.is_empty() || trimmed.starts_with("--") || trimmed.starts_with("/*") {
157            continue;
158        }
159
160        // 检查是否是 CREATE TABLE 语句的开始
161        if !in_create_table && create_table_regex.is_match(line) {
162            in_create_table = true;
163            current_statement.clear();
164            paren_count = 0;
165            in_string = false;
166            escape_next = false;
167        }
168
169        if in_create_table {
170            current_statement.push_str(line);
171            current_statement.push('\n');
172
173            // 逐字符分析以正确处理括号平衡
174            for ch in line.chars() {
175                if escape_next {
176                    escape_next = false;
177                    continue;
178                }
179
180                match ch {
181                    '\\' if in_string => {
182                        escape_next = true;
183                    }
184                    '\'' | '"' | '`' => {
185                        in_string = !in_string;
186                    }
187                    '(' if !in_string => {
188                        paren_count += 1;
189                    }
190                    ')' if !in_string => {
191                        paren_count -= 1;
192                    }
193                    ';' if !in_string && paren_count <= 0 => {
194                        // 找到语句结束
195                        statements.push(current_statement.trim().to_string());
196                        current_statement.clear();
197                        in_create_table = false;
198                        paren_count = 0;
199                        break;
200                    }
201                    _ => {}
202                }
203            }
204        }
205    }
206
207    // 处理可能没有分号结尾的语句
208    if in_create_table && !current_statement.trim().is_empty() {
209        statements.push(current_statement.trim().to_string());
210    }
211
212    debug!("Extracted {} CREATE TABLE statements", statements.len());
213    Ok(statements)
214}
215
216/// 解析列定义
217fn parse_column_definition(column: &ColumnDef) -> Result<TableColumn, DuckError> {
218    let column_name = ident_to_string(&column.name);
219    let data_type = format_data_type(&column.data_type);
220
221    let mut nullable = true;
222    let mut default_value = None;
223    let mut comment = None;
224    let mut auto_increment = false;
225
226    // 检查列选项
227    for option in &column.options {
228        match &option.option {
229            sqlparser::ast::ColumnOption::NotNull => {
230                nullable = false;
231            }
232            sqlparser::ast::ColumnOption::Default(expr) => {
233                default_value = Some(format_default_value(expr));
234            }
235            sqlparser::ast::ColumnOption::Comment(c) => {
236                comment = Some(c.clone());
237            }
238            sqlparser::ast::ColumnOption::Unique(_) => {
239                // sqlparser 0.62: 列级 UNIQUE 不再包含 is_primary 字段
240                // 列级 PRIMARY KEY 现在是独立的 ColumnOption::PrimaryKey 变体
241            }
242            sqlparser::ast::ColumnOption::PrimaryKey(_) => {
243                nullable = false; // 主键不能为空
244            }
245            sqlparser::ast::ColumnOption::DialectSpecific(tokens) => {
246                // 检查是否是AUTO_INCREMENT
247                let token_str = tokens
248                    .iter()
249                    .map(|t| t.to_string())
250                    .collect::<Vec<_>>()
251                    .join(" ")
252                    .to_uppercase();
253                if token_str.contains("AUTO_INCREMENT") {
254                    auto_increment = true;
255                }
256            }
257            _ => {}
258        }
259    }
260
261    Ok(TableColumn {
262        name: column_name,
263        data_type,
264        nullable,
265        default_value,
266        auto_increment,
267        comment,
268    })
269}
270
271/// 解析表约束
272fn parse_table_constraint(constraint: &TableConstraint) -> Result<Option<TableIndex>, DuckError> {
273    match constraint {
274        TableConstraint::PrimaryKey(pk) => {
275            let column_names = extract_index_columns(&pk.columns);
276
277            Ok(Some(TableIndex {
278                name: "PRIMARY".to_string(),
279                columns: column_names,
280                is_primary: true,
281                is_unique: true,
282                index_type: Some("PRIMARY".to_string()),
283            }))
284        }
285        TableConstraint::Unique(uq) => {
286            let column_names = extract_index_columns(&uq.columns);
287            let index_name = uq
288                .name
289                .as_ref()
290                .map(ident_to_string)
291                .unwrap_or_else(|| format!("unique_{}", column_names.join("_")));
292
293            Ok(Some(TableIndex {
294                name: index_name,
295                columns: column_names,
296                is_primary: false,
297                is_unique: true,
298                index_type: Some("UNIQUE".to_string()),
299            }))
300        }
301        TableConstraint::Index(idx) => {
302            let column_names = extract_index_columns(&idx.columns);
303            let index_name = idx
304                .name
305                .as_ref()
306                .map(ident_to_string)
307                .unwrap_or_else(|| format!("idx_{}", column_names.join("_")));
308
309            Ok(Some(TableIndex {
310                name: index_name,
311                columns: column_names,
312                is_primary: false,
313                is_unique: false,
314                index_type: Some("INDEX".to_string()),
315            }))
316        }
317        _ => Ok(None),
318    }
319}
320
321/// 格式化默认值(特别处理函数类型的默认值)
322fn format_default_value(expr: &sqlparser::ast::Expr) -> String {
323    debug!("format_default_value called, expression: {:?}", expr);
324
325    match expr {
326        // 处理函数调用,如 CURRENT_TIMESTAMP
327        sqlparser::ast::Expr::Function(function) => {
328            let function_name = function.name.to_string();
329            debug!("Detected function call: {}", function_name);
330            // 对于 MySQL 的日期时间函数,不需要加引号,直接返回函数名
331            match function_name.to_uppercase().as_str() {
332                "CURRENT_TIMESTAMP" | "NOW" | "CURRENT_DATE" | "CURRENT_TIME"
333                | "LOCALTIMESTAMP" | "LOCALTIME" => {
334                    debug!(
335                        "Recognized as MySQL datetime function, returning: {}",
336                        function_name
337                    );
338                    function_name
339                }
340                _ => {
341                    debug!("Other function, using default format: {}", function_name);
342                    // 其他函数保持原有格式
343                    format!("{expr}")
344                }
345            }
346        }
347
348        // 处理各种值类型
349        sqlparser::ast::Expr::Value(value_with_span) => {
350            debug!("Detected value type: {:?}", value_with_span);
351            match &value_with_span.value {
352                sqlparser::ast::Value::SingleQuotedString(s) => {
353                    debug!("String value: {} -> '{}'", s, s);
354                    format!("'{}'", s)
355                }
356                sqlparser::ast::Value::Number(_, _) => {
357                    debug!("Numeric value");
358                    // 数字类型不需要引号,直接返回表达式格式化结果
359                    format!("{expr}")
360                }
361                sqlparser::ast::Value::Null => {
362                    debug!("NULL value");
363                    "NULL".to_string()
364                }
365                sqlparser::ast::Value::Boolean(b) => {
366                    debug!("Boolean value: {}", b);
367                    b.to_string()
368                }
369                // 处理其他值类型
370                _ => {
371                    debug!("Other value type");
372                    format!("{expr}")
373                }
374            }
375        }
376
377        // 其他情况使用默认格式化
378        _ => {
379            debug!("Other expression type");
380            format!("{expr}")
381        }
382    }
383}
384
385/// 格式化数据类型
386fn format_data_type(data_type: &DataType) -> String {
387    match data_type {
388        DataType::Char(size) => {
389            if let Some(size) = size {
390                format!("CHAR({size})")
391            } else {
392                "CHAR".to_string()
393            }
394        }
395        DataType::Varchar(size) => {
396            if let Some(size) = size {
397                format!("VARCHAR({size})")
398            } else {
399                "VARCHAR".to_string()
400            }
401        }
402        DataType::Text => "TEXT".to_string(),
403        DataType::Int(_) => "INT".to_string(),
404        DataType::BigInt(_) => "BIGINT".to_string(),
405        DataType::TinyInt(_) => "TINYINT".to_string(),
406        DataType::SmallInt(_) => "SMALLINT".to_string(),
407        DataType::MediumInt(_) => "MEDIUMINT".to_string(),
408        DataType::Float(_) => "FLOAT".to_string(),
409        DataType::Double(_) => "DOUBLE".to_string(),
410        DataType::Decimal(exact_number_info) => match exact_number_info {
411            sqlparser::ast::ExactNumberInfo::PrecisionAndScale(precision, scale) => {
412                format!("DECIMAL({precision},{scale})")
413            }
414            sqlparser::ast::ExactNumberInfo::Precision(precision) => {
415                format!("DECIMAL({precision})")
416            }
417            sqlparser::ast::ExactNumberInfo::None => "DECIMAL".to_string(),
418        },
419        DataType::Boolean => "BOOLEAN".to_string(),
420        DataType::Date => "DATE".to_string(),
421        DataType::Time(_, _) => "TIME".to_string(),
422        DataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
423        DataType::Datetime(_) => "DATETIME".to_string(),
424        DataType::JSON => "JSON".to_string(),
425        DataType::Enum(variants, _max_length) => {
426            // 正确处理 ENUM 变体
427            let enum_values: Vec<String> = variants
428                .iter()
429                .map(|variant| match variant {
430                    sqlparser::ast::EnumMember::Name(name) => format!("'{}'", name),
431                    sqlparser::ast::EnumMember::NamedValue(name, _expr) => {
432                        format!("'{}'", name)
433                    }
434                })
435                .collect();
436
437            if enum_values.is_empty() {
438                "ENUM()".to_string()
439            } else {
440                format!("ENUM({})", enum_values.join(","))
441            }
442        }
443        _ => format!("{data_type:?}"), // 对于其他类型,使用 Debug 格式
444    }
445}
446
447/// 检查列是否是列级别的主键
448fn is_column_primary_key(column: &ColumnDef) -> bool {
449    for option in &column.options {
450        if let sqlparser::ast::ColumnOption::PrimaryKey(_) = &option.option {
451            return true;
452        }
453    }
454    false
455}
456
457/// 从 IndexColumn 列表中提取列名
458///
459/// 处理三种情况:
460/// 1. 简单列名:`column_name`
461/// 2. 复合标识符:`table.column` (只取最后一部分)
462/// 3. 复杂表达式:函数索引等 (使用 Display)
463fn extract_index_columns(index_columns: &[sqlparser::ast::IndexColumn]) -> Vec<String> {
464    index_columns
465        .iter()
466        .filter_map(|index_col| {
467            match &index_col.column.expr {
468                sqlparser::ast::Expr::Identifier(ident) => Some(strip_backticks(&ident.value)),
469                sqlparser::ast::Expr::CompoundIdentifier(idents) => {
470                    // 处理 table.column 格式,只取最后一个部分
471                    idents.last().map(|id| strip_backticks(&id.value))
472                }
473                _ => {
474                    // 对于函数索引等复杂表达式,使用 Display
475                    Some(strip_backticks(&index_col.column.to_string()))
476                }
477            }
478        })
479        .collect()
480}
481
482/// 解析独立的 CREATE INDEX 语句并添加到表定义中
483///
484/// 使用 sqlparser 库正确解析 SQL 语法
485///
486/// 格式示例:
487/// ```sql
488/// create index idx_space_id
489///     on agent_config (space_id);
490///
491/// create unique index uk_name
492///     on users (username);
493/// ```
494fn parse_standalone_indexes(
495    sql_content: &str,
496    tables: &mut HashMap<String, TableDefinition>,
497) -> Result<(), DuckError> {
498    let dialect = MySqlDialect {};
499    let mut index_count = 0;
500
501    // 提取所有 CREATE INDEX 语句
502    let index_statements = extract_create_index_statements(sql_content)?;
503
504    for index_sql in index_statements {
505        debug!("Parsing CREATE INDEX statement: {}", index_sql);
506
507        match Parser::parse_sql(&dialect, &index_sql) {
508            Ok(statements) => {
509                for statement in statements {
510                    if let Statement::CreateIndex(create_index) = statement {
511                        // 提取索引名称
512                        let index_name = create_index
513                            .name
514                            .as_ref()
515                            .map(ident_to_string)
516                            .unwrap_or_else(|| "unnamed_index".to_string());
517
518                        // 提取表名
519                        let table_name = ident_to_string(&create_index.table_name);
520
521                        // 提取列名列表
522                        let columns = extract_index_columns(&create_index.columns);
523
524                        if columns.is_empty() {
525                            warn!("Index {} has no column definition, skipping", index_name);
526                            continue;
527                        }
528
529                        // 检查是否是 UNIQUE 索引
530                        let is_unique = create_index.unique;
531
532                        // 查找对应的表
533                        if let Some(table_def) = tables.get_mut(&table_name) {
534                            // 检查是否已经存在同名索引
535                            if table_def.indexes.iter().any(|idx| idx.name == index_name) {
536                                debug!(
537                                    "Index {} already exists in table {}, skipping",
538                                    index_name, table_name
539                                );
540                                continue;
541                            }
542
543                            // 添加索引到表定义
544                            table_def.indexes.push(TableIndex {
545                                name: index_name.clone(),
546                                columns: columns.clone(),
547                                is_primary: false,
548                                is_unique,
549                                index_type: if is_unique {
550                                    Some("UNIQUE".to_string())
551                                } else {
552                                    Some("INDEX".to_string())
553                                },
554                            });
555
556                            index_count += 1;
557                            debug!(
558                                "添加独立索引: {} 到表 {} (列: {:?}, unique: {})",
559                                index_name, table_name, columns, is_unique
560                            );
561                        } else {
562                            warn!(
563                                "Index {} references table {} which does not exist, skipping",
564                                index_name, table_name
565                            );
566                        }
567                    }
568                }
569            }
570            Err(e) => {
571                warn!(
572                    "Failed to parse CREATE INDEX statement: {} - error: {}",
573                    index_sql, e
574                );
575            }
576        }
577    }
578
579    if index_count > 0 {
580        info!(
581            "Successfully parsed {} standalone CREATE INDEX statements",
582            index_count
583        );
584    }
585
586    Ok(())
587}
588
589/// 提取所有 CREATE INDEX 语句
590///
591/// 使用简单的状态机来识别完整的 CREATE INDEX 语句
592fn extract_create_index_statements(sql_content: &str) -> Result<Vec<String>, DuckError> {
593    let mut statements = Vec::new();
594    let mut current_statement = String::new();
595    let mut in_create_index = false;
596
597    // 正则表达式只用于识别语句开始,不用于解析
598    let create_index_regex = Regex::new(r"(?i)^\s*CREATE\s+(UNIQUE\s+)?INDEX")
599        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {}", e)))?;
600
601    for line in sql_content.lines() {
602        let trimmed = line.trim();
603
604        // 跳过空行和注释
605        if trimmed.is_empty() || trimmed.starts_with("--") {
606            continue;
607        }
608
609        // 检查是否是 CREATE INDEX 语句的开始
610        if !in_create_index && create_index_regex.is_match(line) {
611            in_create_index = true;
612            current_statement.clear();
613        }
614
615        if in_create_index {
616            current_statement.push_str(line);
617            current_statement.push(' ');
618
619            // 检查是否遇到分号(语句结束)
620            if trimmed.ends_with(';') {
621                statements.push(current_statement.trim().to_string());
622                current_statement.clear();
623                in_create_index = false;
624            }
625        }
626    }
627
628    // 处理可能没有分号结尾的语句
629    if in_create_index && !current_statement.trim().is_empty() {
630        statements.push(current_statement.trim().to_string());
631    }
632
633    debug!("Extracted {} CREATE INDEX statements", statements.len());
634    Ok(statements)
635}