use super::types::{TableColumn, TableDefinition, TableIndex};
use crate::error::DuckError;
use regex::Regex;
use sqlparser::ast::{ColumnDef, DataType, Statement, TableConstraint};
use sqlparser::dialect::MySqlDialect;
use sqlparser::parser::Parser;
use std::collections::HashMap;
use tracing::{debug, info, warn};
/// Returns `s` without any leading or trailing backtick characters.
///
/// Backticks in the middle of the string are left untouched.
#[inline]
fn strip_backticks(s: &str) -> String {
    let without_leading = s.trim_start_matches('`');
    without_leading.trim_end_matches('`').to_owned()
}
/// Renders an identifier-like value via `ToString` and removes the backtick
/// quoting around it.
///
/// Backticks are trimmed from every dot-separated segment rather than only
/// from the two ends of the whole string, so a qualified name rendered as
/// `` `db`.`users` `` becomes `db.users` instead of ``db`.`users``.
/// Unqualified names, and quoted names that merely contain a dot, produce the
/// same result as trimming only the outer backticks.
#[inline]
fn ident_to_string<T: ToString>(ident: &T) -> String {
    ident
        .to_string()
        .split('.')
        .map(|segment| segment.trim_matches('`'))
        .collect::<Vec<_>>()
        .join(".")
}
/// Parses the `CREATE TABLE` statements contained in `sql_content`.
///
/// Each extracted statement is parsed on its own with the MySQL dialect. A
/// statement that fails to parse is logged with `warn!` and skipped; it does
/// not abort the whole run. For every parsed table the result records:
///
/// * its columns, via `parse_column_definition`;
/// * a `PRIMARY` index built from columns flagged as primary key inline;
/// * every index produced by `parse_table_constraint` for the table-level
///   constraints.
///
/// `engine` and `charset` are always stored as `None`. After all tables are
/// collected, standalone `CREATE INDEX` statements are attached through
/// `parse_standalone_indexes`.
///
/// # Errors
///
/// Propagates any `DuckError` returned by the statement extractors or by the
/// column / constraint / index helpers.
pub fn parse_sql_tables(sql_content: &str) -> Result<HashMap<String, TableDefinition>, DuckError> {
    let mut tables = HashMap::new();
    let dialect = MySqlDialect {};

    for create_table_sql in extract_create_table_statements_with_regex(sql_content)? {
        debug!("Parsing CREATE TABLE statement: {}", create_table_sql);

        let parsed_statements = match Parser::parse_sql(&dialect, &create_table_sql) {
            Ok(parsed) => parsed,
            Err(e) => {
                warn!("Failed to parse SQL statement: {} - error: {}", create_table_sql, e);
                continue;
            }
        };

        for statement in parsed_statements {
            // Anything other than CREATE TABLE is irrelevant here.
            let Statement::CreateTable(create_table) = statement else {
                continue;
            };

            let table_name = ident_to_string(&create_table.name);
            debug!("Parsing table: {}", table_name);

            // Columns, plus the names of those declared PRIMARY KEY inline.
            let mut columns = Vec::new();
            let mut inline_primary_keys = Vec::new();
            for column in &create_table.columns {
                columns.push(parse_column_definition(column)?);
                if is_column_primary_key(column) {
                    inline_primary_keys.push(ident_to_string(&column.name));
                }
            }

            let mut indexes = Vec::new();
            if !inline_primary_keys.is_empty() {
                indexes.push(TableIndex {
                    name: "PRIMARY".to_string(),
                    columns: inline_primary_keys,
                    is_primary: true,
                    is_unique: true,
                    index_type: Some("PRIMARY".to_string()),
                });
            }
            for constraint in &create_table.constraints {
                // `None` means the constraint does not describe an index.
                indexes.extend(parse_table_constraint(constraint)?);
            }

            tables.insert(
                table_name.clone(),
                TableDefinition {
                    name: table_name,
                    columns,
                    indexes,
                    engine: None,
                    charset: None,
                },
            );
        }
    }

    parse_standalone_indexes(sql_content, &mut tables)?;
    info!("Successfully parsed {} tables", tables.len());
    Ok(tables)
}
/// Extracts the `CREATE TABLE` statements of `sql_content`, skipping
/// everything up to and including the first `USE ...;` line if there is one.
///
/// When no `USE` line exists the whole input is scanned. When the `USE` line
/// is the very last line (nothing follows it), the original input is scanned
/// unchanged instead of an empty remainder.
///
/// # Errors
///
/// Returns a `DuckError` if the `USE` pattern fails to compile, or whatever
/// `extract_create_table_statements_from_content` returns.
fn extract_create_table_statements_with_regex(sql_content: &str) -> Result<Vec<String>, DuckError> {
    let use_regex = Regex::new(r"(?i)^\s*USE\s+[^;]+;\s*$")
        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;

    let lines: Vec<&str> = sql_content.lines().collect();

    // Index of the first line to scan: the line right after the first USE
    // statement, or 0 when the file has none.
    let first_line_to_scan = match lines.iter().position(|line| use_regex.is_match(line)) {
        Some(use_idx) => {
            debug!("Found USE statement at line {}: {}", use_idx + 1, lines[use_idx]);
            use_idx + 1
        }
        None => {
            debug!("No USE statement found, parsing entire file from the beginning");
            0
        }
    };

    let content_to_parse = match lines.get(first_line_to_scan..) {
        Some(remaining) if !remaining.is_empty() => remaining.join("\n"),
        // Nothing left after the USE line (or empty input): use the raw text.
        _ => sql_content.to_string(),
    };

    extract_create_table_statements_from_content(&content_to_parse)
}
/// Splits `content` into individual `CREATE TABLE` statements.
///
/// The input is scanned line by line. A statement starts on a line matching
/// `CREATE TABLE` (case-insensitive, leading whitespace allowed) and ends at
/// the first `;` that is outside any quoted region and outside parentheses.
/// The whole terminating line is kept as part of the statement. A statement
/// still open at the end of the input is returned as well.
///
/// Quoting rules used by the scanner:
///
/// * `'...'`, `"..."` and `` `...` `` each open a quoted region that is closed
///   only by the *same* character, so an apostrophe inside a double-quoted
///   string (or a double quote inside a comment literal) no longer flips the
///   state. A doubled quote (`''`) closes and immediately reopens the region,
///   which nets out correctly.
/// * A backslash escapes the next character inside `'...'` and `"..."`
///   literals only; backtick-quoted identifiers have no backslash escapes.
/// * Outside quotes, `--` followed by whitespace or end of line starts a
///   trailing comment; the rest of that line is not scanned, so quotes or
///   semicolons inside the comment cannot corrupt the state.
///
/// Limitations kept from the previous behavior: blank lines and lines whose
/// trimmed text starts with `--` or `/*` are dropped entirely (even inside a
/// statement), and multi-line `/* ... */` comments are not understood.
///
/// # Errors
///
/// Returns a `DuckError` if the `CREATE TABLE` pattern fails to compile.
fn extract_create_table_statements_from_content(content: &str) -> Result<Vec<String>, DuckError> {
    let create_table_regex = Regex::new(r"(?i)^\s*CREATE\s+TABLE")
        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {e}")))?;

    // True when `rest` begins a `-- ` style comment (dash dash, then
    // whitespace or end of line).
    let starts_line_comment = |rest: &str| match rest.strip_prefix("--") {
        Some(after) => after.chars().next().map_or(true, char::is_whitespace),
        None => false,
    };

    let mut statements = Vec::new();
    let mut current_statement = String::new();
    let mut in_create_table = false;
    let mut paren_count: i32 = 0;
    // The quote character that opened the region we are currently inside.
    let mut open_quote: Option<char> = None;
    let mut escape_next = false;

    for line in content.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with("--") || trimmed.starts_with("/*") {
            continue;
        }

        if !in_create_table && create_table_regex.is_match(line) {
            in_create_table = true;
            current_statement.clear();
            paren_count = 0;
            open_quote = None;
            escape_next = false;
        }
        if !in_create_table {
            continue;
        }

        current_statement.push_str(line);
        current_statement.push('\n');

        for (idx, ch) in line.char_indices() {
            if escape_next {
                escape_next = false;
                continue;
            }
            match (open_quote, ch) {
                // Backslash escapes exist only in string literals.
                (Some('\'' | '"'), '\\') => escape_next = true,
                // Only the matching quote character closes the region.
                (Some(quote), _) if ch == quote => open_quote = None,
                (Some(_), _) => {}
                (None, '\'' | '"' | '`') => open_quote = Some(ch),
                // Trailing comment: ignore the remainder of the line.
                (None, '-') if starts_line_comment(&line[idx..]) => break,
                (None, '(') => paren_count += 1,
                (None, ')') => paren_count -= 1,
                (None, ';') if paren_count <= 0 => {
                    statements.push(current_statement.trim().to_string());
                    current_statement.clear();
                    in_create_table = false;
                    paren_count = 0;
                    break;
                }
                (None, _) => {}
            }
        }
    }

    // Keep a final statement that was never terminated by `;`.
    if in_create_table && !current_statement.trim().is_empty() {
        statements.push(current_statement.trim().to_string());
    }

    debug!("Extracted {} CREATE TABLE statements", statements.len());
    Ok(statements)
}
/// Converts a parsed column definition into a `TableColumn`.
///
/// The column starts out nullable, with no default, no comment and without
/// auto-increment. Its options then refine that:
///
/// * `NOT NULL` and an inline primary-key option clear `nullable`;
/// * `DEFAULT` is rendered through `format_default_value`;
/// * `COMMENT` text is copied as-is;
/// * a dialect-specific option whose tokens contain `AUTO_INCREMENT`
///   (case-insensitive) sets `auto_increment`.
///
/// All other options are ignored. The function currently always returns `Ok`.
fn parse_column_definition(column: &ColumnDef) -> Result<TableColumn, DuckError> {
    let mut parsed = TableColumn {
        name: ident_to_string(&column.name),
        data_type: format_data_type(&column.data_type),
        nullable: true,
        default_value: None,
        auto_increment: false,
        comment: None,
    };

    for option_def in &column.options {
        match &option_def.option {
            sqlparser::ast::ColumnOption::NotNull => parsed.nullable = false,
            // A primary-key column is implicitly NOT NULL.
            sqlparser::ast::ColumnOption::Unique { is_primary: true, .. } => {
                parsed.nullable = false;
            }
            sqlparser::ast::ColumnOption::Default(expr) => {
                parsed.default_value = Some(format_default_value(expr));
            }
            sqlparser::ast::ColumnOption::Comment(text) => {
                parsed.comment = Some(text.clone());
            }
            sqlparser::ast::ColumnOption::DialectSpecific(tokens) => {
                let rendered: Vec<String> = tokens.iter().map(|token| token.to_string()).collect();
                if rendered.join(" ").to_uppercase().contains("AUTO_INCREMENT") {
                    parsed.auto_increment = true;
                }
            }
            _ => {}
        }
    }

    Ok(parsed)
}
fn parse_table_constraint(constraint: &TableConstraint) -> Result<Option<TableIndex>, DuckError> {
match constraint {
TableConstraint::PrimaryKey { columns, .. } => {
let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
Ok(Some(TableIndex {
name: "PRIMARY".to_string(),
columns: column_names,
is_primary: true,
is_unique: true,
index_type: Some("PRIMARY".to_string()),
}))
}
TableConstraint::Unique { columns, name, .. } => {
let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
let index_name = name
.as_ref()
.map(ident_to_string)
.unwrap_or_else(|| format!("unique_{}", column_names.join("_")));
Ok(Some(TableIndex {
name: index_name,
columns: column_names,
is_primary: false,
is_unique: true,
index_type: Some("UNIQUE".to_string()),
}))
}
TableConstraint::Index { name, columns, .. } => {
let column_names: Vec<String> = columns.iter().map(ident_to_string).collect();
let index_name = name
.as_ref()
.map(ident_to_string)
.unwrap_or_else(|| format!("idx_{}", column_names.join("_")));
Ok(Some(TableIndex {
name: index_name,
columns: column_names,
is_primary: false,
is_unique: false,
index_type: Some("INDEX".to_string()),
}))
}
_ => Ok(None),
}
}
/// Renders a column `DEFAULT` expression as text.
///
/// * Calls to `CURRENT_TIMESTAMP`, `NOW`, `CURRENT_DATE`, `CURRENT_TIME`,
///   `LOCALTIMESTAMP` or `LOCALTIME` (matched case-insensitively) are reduced
///   to the function name exactly as written in the source.
/// * Single-quoted strings are wrapped in `'...'` with their content copied
///   verbatim.
/// * `NULL` becomes `NULL`; booleans become `true` / `false`.
/// * Everything else uses the expression's own `Display` output.
///
/// Each decision is traced with `debug!`.
fn format_default_value(expr: &sqlparser::ast::Expr) -> String {
    debug!("format_default_value called, expression: {:?}", expr);

    let value = match expr {
        sqlparser::ast::Expr::Function(function) => {
            let function_name = function.name.to_string();
            debug!("Detected function call: {}", function_name);

            let is_datetime_function = matches!(
                function_name.to_uppercase().as_str(),
                "CURRENT_TIMESTAMP"
                    | "NOW"
                    | "CURRENT_DATE"
                    | "CURRENT_TIME"
                    | "LOCALTIMESTAMP"
                    | "LOCALTIME"
            );
            if is_datetime_function {
                debug!("Recognized as MySQL datetime function, returning: {}", function_name);
                return function_name;
            }
            debug!("Other function, using default format: {}", function_name);
            return expr.to_string();
        }
        sqlparser::ast::Expr::Value(value_with_span) => {
            debug!("Detected value type: {:?}", value_with_span);
            &value_with_span.value
        }
        _ => {
            debug!("Other expression type");
            return expr.to_string();
        }
    };

    // Only literal values reach this point.
    match value {
        sqlparser::ast::Value::SingleQuotedString(s) => {
            debug!("String value: {} -> '{}'", s, s);
            format!("'{}'", s)
        }
        sqlparser::ast::Value::Number(_, _) => {
            debug!("Numeric value");
            expr.to_string()
        }
        sqlparser::ast::Value::Null => {
            debug!("NULL value");
            String::from("NULL")
        }
        sqlparser::ast::Value::Boolean(b) => {
            debug!("Boolean value: {}", b);
            b.to_string()
        }
        _ => {
            debug!("Other value type");
            expr.to_string()
        }
    }
}
/// Renders a parsed column type as an upper-case type string.
///
/// * `CHAR` / `VARCHAR` keep their length when one is given.
/// * Integer, floating-point, time-like types drop any display width or
///   precision (`INT(11)` -> `INT`).
/// * `DECIMAL` keeps precision and scale when present.
/// * `ENUM` lists its member names, each single-quoted and comma-separated
///   without spaces; explicit member values are dropped.
/// * Any type not listed here falls back to its `Debug` representation.
fn format_data_type(data_type: &DataType) -> String {
    match data_type {
        DataType::Char(size) => match size {
            Some(len) => format!("CHAR({len})"),
            None => String::from("CHAR"),
        },
        DataType::Varchar(size) => match size {
            Some(len) => format!("VARCHAR({len})"),
            None => String::from("VARCHAR"),
        },
        DataType::Text => String::from("TEXT"),
        DataType::Int(_) => String::from("INT"),
        DataType::BigInt(_) => String::from("BIGINT"),
        DataType::TinyInt(_) => String::from("TINYINT"),
        DataType::SmallInt(_) => String::from("SMALLINT"),
        DataType::MediumInt(_) => String::from("MEDIUMINT"),
        DataType::Float(_) => String::from("FLOAT"),
        DataType::Double(_) => String::from("DOUBLE"),
        DataType::Decimal(sqlparser::ast::ExactNumberInfo::PrecisionAndScale(precision, scale)) => {
            format!("DECIMAL({precision},{scale})")
        }
        DataType::Decimal(sqlparser::ast::ExactNumberInfo::Precision(precision)) => {
            format!("DECIMAL({precision})")
        }
        DataType::Decimal(sqlparser::ast::ExactNumberInfo::None) => String::from("DECIMAL"),
        DataType::Boolean => String::from("BOOLEAN"),
        DataType::Date => String::from("DATE"),
        DataType::Time(_, _) => String::from("TIME"),
        DataType::Timestamp(_, _) => String::from("TIMESTAMP"),
        DataType::Datetime(_) => String::from("DATETIME"),
        DataType::JSON => String::from("JSON"),
        DataType::Enum(variants, _max_length) => {
            let quoted_members: Vec<String> = variants
                .iter()
                .map(|variant| match variant {
                    sqlparser::ast::EnumMember::Name(name)
                    | sqlparser::ast::EnumMember::NamedValue(name, _) => format!("'{}'", name),
                })
                .collect();
            // An empty member list joins to "", which yields "ENUM()".
            format!("ENUM({})", quoted_members.join(","))
        }
        other => format!("{other:?}"),
    }
}
/// Reports whether the column carries an inline `PRIMARY KEY` option.
///
/// A plain inline `UNIQUE` option (not primary) does not count.
fn is_column_primary_key(column: &ColumnDef) -> bool {
    column.options.iter().any(|option_def| {
        matches!(
            &option_def.option,
            sqlparser::ast::ColumnOption::Unique { is_primary: true, .. }
        )
    })
}
/// Collects the column names referenced by the columns of an index.
///
/// * A plain identifier contributes its value.
/// * A compound identifier (`a.b.c`) contributes only its last part; an empty
///   compound identifier contributes nothing.
/// * Any other expression contributes the textual form of the whole index
///   column entry.
///
/// Surrounding backticks are stripped from every name.
fn extract_index_columns(index_columns: &[sqlparser::ast::IndexColumn]) -> Vec<String> {
    let mut names = Vec::new();

    for index_col in index_columns {
        let name = match &index_col.column.expr {
            sqlparser::ast::Expr::Identifier(ident) => strip_backticks(&ident.value),
            sqlparser::ast::Expr::CompoundIdentifier(idents) => match idents.last() {
                Some(last_part) => strip_backticks(&last_part.value),
                None => continue,
            },
            _ => strip_backticks(&index_col.column.to_string()),
        };
        names.push(name);
    }

    names
}
/// Attaches standalone `CREATE [UNIQUE] INDEX` statements found in
/// `sql_content` to the matching entries of `tables`.
///
/// Each extracted statement is parsed with the MySQL dialect. An index is
/// skipped (with a log line) when:
///
/// * the statement fails to parse;
/// * no column names could be extracted;
/// * the referenced table is not present in `tables`;
/// * the table already has an index with the same name.
///
/// An index without a name is recorded as `unnamed_index`.
///
/// # Errors
///
/// Propagates the error from `extract_create_index_statements`.
fn parse_standalone_indexes(
    sql_content: &str,
    tables: &mut HashMap<String, TableDefinition>,
) -> Result<(), DuckError> {
    let dialect = MySqlDialect {};
    let mut added = 0;

    for index_sql in extract_create_index_statements(sql_content)? {
        debug!("Parsing CREATE INDEX statement: {}", index_sql);

        let parsed_statements = match Parser::parse_sql(&dialect, &index_sql) {
            Ok(parsed) => parsed,
            Err(e) => {
                warn!("Failed to parse CREATE INDEX statement: {} - error: {}", index_sql, e);
                continue;
            }
        };

        for statement in parsed_statements {
            let Statement::CreateIndex(create_index) = statement else {
                continue;
            };

            let index_name = match create_index.name.as_ref() {
                Some(name) => ident_to_string(name),
                None => "unnamed_index".to_string(),
            };
            let table_name = ident_to_string(&create_index.table_name);

            let columns = extract_index_columns(&create_index.columns);
            if columns.is_empty() {
                warn!("Index {} has no column definition, skipping", index_name);
                continue;
            }
            let is_unique = create_index.unique;

            let Some(table_def) = tables.get_mut(&table_name) else {
                warn!("Index {} references table {} which does not exist, skipping", index_name, table_name);
                continue;
            };

            // An index of that name was already recorded for this table.
            if table_def.indexes.iter().any(|existing| existing.name == index_name) {
                debug!("Index {} already exists in table {}, skipping", index_name, table_name);
                continue;
            }

            let index_kind = if is_unique { "UNIQUE" } else { "INDEX" };
            table_def.indexes.push(TableIndex {
                name: index_name.clone(),
                columns: columns.clone(),
                is_primary: false,
                is_unique,
                index_type: Some(index_kind.to_string()),
            });
            added += 1;
            debug!(
                "添加独立索引: {} 到表 {} (列: {:?}, unique: {})",
                index_name, table_name, columns, is_unique
            );
        }
    }

    if added > 0 {
        info!("Successfully parsed {} standalone CREATE INDEX statements", added);
    }
    Ok(())
}
/// Collects the `CREATE [UNIQUE] INDEX` statements of `sql_content`.
///
/// A statement starts on a line matching `CREATE INDEX` / `CREATE UNIQUE
/// INDEX` (case-insensitive, leading whitespace allowed) and runs until a
/// line whose trimmed text ends with `;`. Its lines are joined with single
/// spaces. Blank lines and lines starting with `--` are dropped. A statement
/// still open at the end of the input is returned too.
///
/// # Errors
///
/// Returns a `DuckError` if the `CREATE INDEX` pattern fails to compile.
fn extract_create_index_statements(sql_content: &str) -> Result<Vec<String>, DuckError> {
    let create_index_regex = Regex::new(r"(?i)^\s*CREATE\s+(UNIQUE\s+)?INDEX")
        .map_err(|e| DuckError::custom(format!("正则表达式编译失败: {}", e)))?;

    let mut statements = Vec::new();
    // `Some(buffer)` while a CREATE INDEX statement is being accumulated.
    let mut pending: Option<String> = None;

    for line in sql_content.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() || trimmed.starts_with("--") {
            continue;
        }

        if pending.is_none() && create_index_regex.is_match(line) {
            pending = Some(String::new());
        }

        let Some(buffer) = pending.as_mut() else {
            continue;
        };
        buffer.push_str(line);
        buffer.push(' ');

        if trimmed.ends_with(';') {
            statements.push(buffer.trim().to_string());
            pending = None;
        }
    }

    // Keep a final statement that was never terminated by `;`.
    if let Some(unterminated) = pending {
        let unterminated = unterminated.trim();
        if !unterminated.is_empty() {
            statements.push(unterminated.to_string());
        }
    }

    debug!("Extracted {} CREATE INDEX statements", statements.len());
    Ok(statements)
}