1use super::types::{TableColumn, TableDefinition, TableIndex};
2use crate::error::DuckError;
3use regex::Regex;
4use sqlparser::ast::{ColumnDef, DataType, Statement, TableConstraint};
5use sqlparser::dialect::MySqlDialect;
6use sqlparser::parser::Parser;
7use std::collections::HashMap;
8use tracing::{debug, info, warn};
9
/// Removes MySQL backtick quoting from an identifier string.
///
/// Every backtick is dropped, not only the outermost pair, so a qualified
/// name rendered as `` `db`.`users` `` becomes `db.users` rather than keeping
/// the inner quotes. Input without backticks is returned unchanged.
#[inline]
fn strip_backticks(s: &str) -> String {
    s.replace('`', "")
}
15
16#[inline]
18fn ident_to_string<T: ToString>(ident: &T) -> String {
19 strip_backticks(&ident.to_string())
20}
21
22pub fn parse_sql_tables(sql_content: &str) -> Result<HashMap<String, TableDefinition>, DuckError> {
24 let mut tables = HashMap::new();
25
26 let create_table_statements = extract_create_table_statements_with_regex(sql_content)?;
28
29 let dialect = MySqlDialect {};
30
31 for create_table_sql in create_table_statements {
32 debug!("Parsing CREATE TABLE statement: {}", create_table_sql);
33
34 match Parser::parse_sql(&dialect, &create_table_sql) {
35 Ok(statements) => {
36 for statement in statements {
37 if let Statement::CreateTable(create_table) = statement {
38 let table_name = ident_to_string(&create_table.name);
40 debug!("Parsing table: {}", table_name);
41
42 let mut table_columns = Vec::new();
43 let mut table_indexes = Vec::new();
44 let mut primary_key_columns = Vec::new();
45
46 for column in &create_table.columns {
48 let column_def = parse_column_definition(column)?;
49
50 if is_column_primary_key(column) {
52 primary_key_columns.push(ident_to_string(&column.name));
53 }
54
55 table_columns.push(column_def);
56 }
57
58 if !primary_key_columns.is_empty() {
60 table_indexes.push(TableIndex {
61 name: "PRIMARY".to_string(),
62 columns: primary_key_columns,
63 is_primary: true,
64 is_unique: true,
65 index_type: Some("PRIMARY".to_string()),
66 });
67 }
68
69 for constraint in &create_table.constraints {
71 if let Some(index) = parse_table_constraint(constraint)? {
72 table_indexes.push(index);
73 }
74 }
75
76 let table_def = TableDefinition {
77 name: table_name.clone(),
78 columns: table_columns,
79 indexes: table_indexes,
80 engine: None, charset: None, };
83
84 tables.insert(table_name, table_def);
85 }
86 }
87 }
88 Err(e) => {
89 warn!("Failed to parse SQL statement: {} - error: {}", create_table_sql, e);
90 }
91 }
92 }
93
94 parse_standalone_indexes(sql_content, &mut tables)?;
96
97 info!("Successfully parsed {} tables", tables.len());
98 Ok(tables)
99}
100
101fn extract_create_table_statements_with_regex(sql_content: &str) -> Result<Vec<String>, DuckError> {
103 let use_regex = Regex::new(r"(?i)^\s*USE\s+[^;]+;\s*$")
105 .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;
106
107 let lines: Vec<&str> = sql_content.lines().collect();
108 let mut start_parsing_from_line = 0;
109
110 for (line_idx, line) in lines.iter().enumerate() {
112 if use_regex.is_match(line) {
113 debug!("Found USE statement at line {}: {}", line_idx + 1, line);
114 start_parsing_from_line = line_idx + 1; break;
116 }
117 }
118
119 if start_parsing_from_line == 0 {
121 debug!("No USE statement found, parsing entire file from the beginning");
122 }
123
124 let content_to_parse = if start_parsing_from_line < lines.len() {
126 lines[start_parsing_from_line..].join("\n")
127 } else {
128 sql_content.to_string()
129 };
130
131 extract_create_table_statements_from_content(&content_to_parse)
132}
133
134fn extract_create_table_statements_from_content(content: &str) -> Result<Vec<String>, DuckError> {
136 let mut statements = Vec::new();
137
138 let create_table_regex = Regex::new(r"(?i)^\s*CREATE\s+TABLE")
140 .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;
141
142 let lines: Vec<&str> = content.lines().collect();
143 let mut current_statement = String::new();
144 let mut in_create_table = false;
145 let mut paren_count = 0;
146 let mut in_string = false;
147 let mut escape_next = false;
148
149 for line in lines {
150 let trimmed = line.trim();
151
152 if trimmed.is_empty() || trimmed.starts_with("--") || trimmed.starts_with("/*") {
154 continue;
155 }
156
157 if !in_create_table && create_table_regex.is_match(line) {
159 in_create_table = true;
160 current_statement.clear();
161 paren_count = 0;
162 in_string = false;
163 escape_next = false;
164 }
165
166 if in_create_table {
167 current_statement.push_str(line);
168 current_statement.push('\n');
169
170 for ch in line.chars() {
172 if escape_next {
173 escape_next = false;
174 continue;
175 }
176
177 match ch {
178 '\\' if in_string => {
179 escape_next = true;
180 }
181 '\'' | '"' | '`' => {
182 in_string = !in_string;
183 }
184 '(' if !in_string => {
185 paren_count += 1;
186 }
187 ')' if !in_string => {
188 paren_count -= 1;
189 }
190 ';' if !in_string && paren_count <= 0 => {
191 statements.push(current_statement.trim().to_string());
193 current_statement.clear();
194 in_create_table = false;
195 paren_count = 0;
196 break;
197 }
198 _ => {}
199 }
200 }
201 }
202 }
203
204 if in_create_table && !current_statement.trim().is_empty() {
206 statements.push(current_statement.trim().to_string());
207 }
208
209 debug!("Extracted {} CREATE TABLE statements", statements.len());
210 Ok(statements)
211}
212
213fn parse_column_definition(column: &ColumnDef) -> Result<TableColumn, DuckError> {
215 let column_name = ident_to_string(&column.name);
216 let data_type = format_data_type(&column.data_type);
217
218 let mut nullable = true;
219 let mut default_value = None;
220 let mut comment = None;
221 let mut auto_increment = false;
222
223 for option in &column.options {
225 match &option.option {
226 sqlparser::ast::ColumnOption::NotNull => {
227 nullable = false;
228 }
229 sqlparser::ast::ColumnOption::Default(expr) => {
230 default_value = Some(format_default_value(expr));
231 }
232 sqlparser::ast::ColumnOption::Comment(c) => {
233 comment = Some(c.clone());
234 }
235 sqlparser::ast::ColumnOption::Unique { is_primary, .. } => {
236 if *is_primary {
237 nullable = false; }
239 }
240 sqlparser::ast::ColumnOption::DialectSpecific(tokens) => {
241 let token_str = tokens
243 .iter()
244 .map(|t| t.to_string())
245 .collect::<Vec<_>>()
246 .join(" ")
247 .to_uppercase();
248 if token_str.contains("AUTO_INCREMENT") {
249 auto_increment = true;
250 }
251 }
252 _ => {}
253 }
254 }
255
256 Ok(TableColumn {
257 name: column_name,
258 data_type,
259 nullable,
260 default_value,
261 auto_increment,
262 comment,
263 })
264}
265
266fn parse_table_constraint(constraint: &TableConstraint) -> Result<Option<TableIndex>, DuckError> {
268 match constraint {
269 TableConstraint::PrimaryKey { columns, .. } => {
270 let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
271
272 Ok(Some(TableIndex {
273 name: "PRIMARY".to_string(),
274 columns: column_names,
275 is_primary: true,
276 is_unique: true,
277 index_type: Some("PRIMARY".to_string()),
278 }))
279 }
280 TableConstraint::Unique { columns, name, .. } => {
281 let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
282 let index_name = name
283 .as_ref()
284 .map(ident_to_string)
285 .unwrap_or_else(|| format!("unique_{}", column_names.join("_")));
286
287 Ok(Some(TableIndex {
288 name: index_name,
289 columns: column_names,
290 is_primary: false,
291 is_unique: true,
292 index_type: Some("UNIQUE".to_string()),
293 }))
294 }
295 TableConstraint::Index { name, columns, .. } => {
296 let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
297 let index_name = name
298 .as_ref()
299 .map(ident_to_string)
300 .unwrap_or_else(|| format!("idx_{}", column_names.join("_")));
301
302 Ok(Some(TableIndex {
303 name: index_name,
304 columns: column_names,
305 is_primary: false,
306 is_unique: false,
307 index_type: Some("INDEX".to_string()),
308 }))
309 }
310 _ => Ok(None),
311 }
312}
313
314fn format_default_value(expr: &sqlparser::ast::Expr) -> String {
316 debug!("format_default_value called, expression: {:?}", expr);
317
318 match expr {
319 sqlparser::ast::Expr::Function(function) => {
321 let function_name = function.name.to_string();
322 debug!("Detected function call: {}", function_name);
323 match function_name.to_uppercase().as_str() {
325 "CURRENT_TIMESTAMP" | "NOW" | "CURRENT_DATE" | "CURRENT_TIME"
326 | "LOCALTIMESTAMP" | "LOCALTIME" => {
327 debug!("Recognized as MySQL datetime function, returning: {}", function_name);
328 function_name
329 }
330 _ => {
331 debug!("Other function, using default format: {}", function_name);
332 format!("{expr}")
334 }
335 }
336 }
337
338 sqlparser::ast::Expr::Value(value_with_span) => {
340 debug!("Detected value type: {:?}", value_with_span);
341 match &value_with_span.value {
342 sqlparser::ast::Value::SingleQuotedString(s) => {
343 debug!("String value: {} -> '{}'", s, s);
344 format!("'{}'", s)
345 }
346 sqlparser::ast::Value::Number(_, _) => {
347 debug!("Numeric value");
348 format!("{expr}")
350 }
351 sqlparser::ast::Value::Null => {
352 debug!("NULL value");
353 "NULL".to_string()
354 }
355 sqlparser::ast::Value::Boolean(b) => {
356 debug!("Boolean value: {}", b);
357 b.to_string()
358 }
359 _ => {
361 debug!("Other value type");
362 format!("{expr}")
363 }
364 }
365 }
366
367 _ => {
369 debug!("Other expression type");
370 format!("{expr}")
371 }
372 }
373}
374
375fn format_data_type(data_type: &DataType) -> String {
377 match data_type {
378 DataType::Char(size) => {
379 if let Some(size) = size {
380 format!("CHAR({size})")
381 } else {
382 "CHAR".to_string()
383 }
384 }
385 DataType::Varchar(size) => {
386 if let Some(size) = size {
387 format!("VARCHAR({size})")
388 } else {
389 "VARCHAR".to_string()
390 }
391 }
392 DataType::Text => "TEXT".to_string(),
393 DataType::Int(_) => "INT".to_string(),
394 DataType::BigInt(_) => "BIGINT".to_string(),
395 DataType::TinyInt(_) => "TINYINT".to_string(),
396 DataType::SmallInt(_) => "SMALLINT".to_string(),
397 DataType::MediumInt(_) => "MEDIUMINT".to_string(),
398 DataType::Float(_) => "FLOAT".to_string(),
399 DataType::Double(_) => "DOUBLE".to_string(),
400 DataType::Decimal(exact_number_info) => match exact_number_info {
401 sqlparser::ast::ExactNumberInfo::PrecisionAndScale(precision, scale) => {
402 format!("DECIMAL({precision},{scale})")
403 }
404 sqlparser::ast::ExactNumberInfo::Precision(precision) => {
405 format!("DECIMAL({precision})")
406 }
407 sqlparser::ast::ExactNumberInfo::None => "DECIMAL".to_string(),
408 },
409 DataType::Boolean => "BOOLEAN".to_string(),
410 DataType::Date => "DATE".to_string(),
411 DataType::Time(_, _) => "TIME".to_string(),
412 DataType::Timestamp(_, _) => "TIMESTAMP".to_string(),
413 DataType::Datetime(_) => "DATETIME".to_string(),
414 DataType::JSON => "JSON".to_string(),
415 DataType::Enum(variants, _max_length) => {
416 let enum_values: Vec<String> = variants
418 .iter()
419 .filter_map(|variant| match variant {
420 sqlparser::ast::EnumMember::Name(name) => Some(format!("'{}'", name)),
421 sqlparser::ast::EnumMember::NamedValue(name, _expr) => {
422 Some(format!("'{}'", name))
423 }
424 })
425 .collect();
426
427 if enum_values.is_empty() {
428 "ENUM()".to_string()
429 } else {
430 format!("ENUM({})", enum_values.join(","))
431 }
432 }
433 _ => format!("{data_type:?}"), }
435}
436
437fn is_column_primary_key(column: &ColumnDef) -> bool {
439 for option in &column.options {
440 if let sqlparser::ast::ColumnOption::Unique { is_primary, .. } = &option.option {
441 if *is_primary {
442 return true;
443 }
444 }
445 }
446 false
447}
448
449fn extract_index_columns(index_columns: &[sqlparser::ast::IndexColumn]) -> Vec<String> {
456 index_columns
457 .iter()
458 .filter_map(|index_col| {
459 match &index_col.column.expr {
460 sqlparser::ast::Expr::Identifier(ident) => Some(strip_backticks(&ident.value)),
461 sqlparser::ast::Expr::CompoundIdentifier(idents) => {
462 idents.last().map(|id| strip_backticks(&id.value))
464 }
465 _ => {
466 Some(strip_backticks(&index_col.column.to_string()))
468 }
469 }
470 })
471 .collect()
472}
473
474fn parse_standalone_indexes(
487 sql_content: &str,
488 tables: &mut HashMap<String, TableDefinition>,
489) -> Result<(), DuckError> {
490 let dialect = MySqlDialect {};
491 let mut index_count = 0;
492
493 let index_statements = extract_create_index_statements(sql_content)?;
495
496 for index_sql in index_statements {
497 debug!("Parsing CREATE INDEX statement: {}", index_sql);
498
499 match Parser::parse_sql(&dialect, &index_sql) {
500 Ok(statements) => {
501 for statement in statements {
502 if let Statement::CreateIndex(create_index) = statement {
503 let index_name = create_index
505 .name
506 .as_ref()
507 .map(ident_to_string)
508 .unwrap_or_else(|| "unnamed_index".to_string());
509
510 let table_name = ident_to_string(&create_index.table_name);
512
513 let columns = extract_index_columns(&create_index.columns);
515
516 if columns.is_empty() {
517 warn!("Index {} has no column definition, skipping", index_name);
518 continue;
519 }
520
521 let is_unique = create_index.unique;
523
524 if let Some(table_def) = tables.get_mut(&table_name) {
526 if table_def.indexes.iter().any(|idx| idx.name == index_name) {
528 debug!("Index {} already exists in table {}, skipping", index_name, table_name);
529 continue;
530 }
531
532 table_def.indexes.push(TableIndex {
534 name: index_name.clone(),
535 columns: columns.clone(),
536 is_primary: false,
537 is_unique,
538 index_type: if is_unique {
539 Some("UNIQUE".to_string())
540 } else {
541 Some("INDEX".to_string())
542 },
543 });
544
545 index_count += 1;
546 debug!(
547 "添加独立索引: {} 到表 {} (列: {:?}, unique: {})",
548 index_name, table_name, columns, is_unique
549 );
550 } else {
551 warn!("Index {} references table {} which does not exist, skipping", index_name, table_name);
552 }
553 }
554 }
555 }
556 Err(e) => {
557 warn!("Failed to parse CREATE INDEX statement: {} - error: {}", index_sql, e);
558 }
559 }
560 }
561
562 if index_count > 0 {
563 info!("Successfully parsed {} standalone CREATE INDEX statements", index_count);
564 }
565
566 Ok(())
567}
568
569fn extract_create_index_statements(sql_content: &str) -> Result<Vec<String>, DuckError> {
573 let mut statements = Vec::new();
574 let mut current_statement = String::new();
575 let mut in_create_index = false;
576
577 let create_index_regex = Regex::new(r"(?i)^\s*CREATE\s+(UNIQUE\s+)?INDEX")
579 .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {}", e)))?;
580
581 for line in sql_content.lines() {
582 let trimmed = line.trim();
583
584 if trimmed.is_empty() || trimmed.starts_with("--") {
586 continue;
587 }
588
589 if !in_create_index && create_index_regex.is_match(line) {
591 in_create_index = true;
592 current_statement.clear();
593 }
594
595 if in_create_index {
596 current_statement.push_str(line);
597 current_statement.push(' ');
598
599 if trimmed.ends_with(';') {
601 statements.push(current_statement.trim().to_string());
602 current_statement.clear();
603 in_create_index = false;
604 }
605 }
606 }
607
608 if in_create_index && !current_statement.trim().is_empty() {
610 statements.push(current_statement.trim().to_string());
611 }
612
613 debug!("Extracted {} CREATE INDEX statements", statements.len());
614 Ok(statements)
615}