use super::schema::*;
use super::types::*;
use crate::config::DatabaseType;
use regex::Regex;
use std::sync::LazyLock;
/// Validates that `identifier` is safe to interpolate into generated SQL.
///
/// Rules: non-empty, at most 64 bytes, matches `[a-zA-Z_][a-zA-Z0-9_]*`,
/// and is not a SQL reserved keyword. Returns the identifier unchanged on
/// success, or a human-readable error message (Chinese, matching the rest
/// of this module) on failure.
fn validate_sql_identifier(identifier: &str, identifier_type: &str) -> Result<String, String> {
    // Reserved words we refuse to use as table/column/index names (lowercase).
    const RESERVED_KEYWORDS: [&str; 31] = [
        "select", "insert", "update", "delete", "drop", "create", "alter",
        "table", "index", "from", "where", "and", "or", "not", "null",
        "primary", "key", "foreign", "references", "constraint", "default",
        "unique", "check", "into", "values", "set", "join", "left", "right",
        "inner", "outer",
    ];
    if identifier.is_empty() {
        return Err(format!("{} 不能为空", identifier_type));
    }
    // Byte length; the regex below restricts identifiers to ASCII anyway,
    // so byte count equals character count for anything that passes.
    if identifier.len() > 64 {
        return Err(format!("{} 长度不能超过 64 个字符", identifier_type));
    }
    // Compiled once and cached for the lifetime of the process.
    static IDENTIFIER_REGEX: LazyLock<Regex> =
        LazyLock::new(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap());
    if !IDENTIFIER_REGEX.is_match(identifier) {
        return Err(format!(
            "{} '{}' 包含无效字符,只允许字母、数字和下划线,且不能以数字开头",
            identifier_type, identifier
        ));
    }
    // Case-insensitive keyword check without allocating a lowercased copy
    // (identifier is guaranteed ASCII by the regex above).
    if RESERVED_KEYWORDS.iter().any(|kw| identifier.eq_ignore_ascii_case(kw)) {
        return Err(format!(
            "{} '{}' 是 SQL 保留关键字,不允许使用",
            identifier_type, identifier
        ));
    }
    Ok(identifier.to_string())
}
/// Sanitizes a column DEFAULT expression before embedding it in DDL.
///
/// If the value contains a pattern commonly found in SQL-injection payloads,
/// the whole value is replaced with the literal `'***SANITIZED***'` (and a
/// warning is logged). Otherwise, values that already look like safe SQL
/// (quoted strings, parenthesized expressions, integers, floats, NULL,
/// CURRENT_TIMESTAMP, NOW()) pass through trimmed; anything else is wrapped
/// in single quotes with embedded quotes doubled.
fn sanitize_default_value(default: &str) -> String {
    // Lowercase patterns; matching below is done on a lowercased copy.
    const SUSPICIOUS_PATTERNS: [&str; 20] = [
        "select", "insert", "update", "delete", "drop", "create", "alter",
        "exec", "execute", "xp_", "sp_", "--", "/*", "*/", "chr(", "char(",
        "concat", "union", "benchmark", "sleep",
    ];
    // BUG FIX: the input was previously UPPERCASED and then compared against
    // these lowercase patterns, so none of the alphabetic patterns could
    // ever match (only "--", "/*", "*/" worked). Lowercase the input instead
    // so the check is genuinely case-insensitive.
    let lowered = default.to_lowercase();
    for pattern in &SUSPICIOUS_PATTERNS {
        if lowered.contains(pattern) {
            tracing::warn!(
                "Suspicious pattern detected in default value: '{}', sanitizing",
                default
            );
            return "'***SANITIZED***'".to_string();
        }
    }
    let trimmed = default.trim();
    // Uppercase once instead of once per comparison.
    let upper = trimmed.to_uppercase();
    if (trimmed.starts_with('\'') && trimmed.ends_with('\''))
        || (trimmed.starts_with('(') && trimmed.ends_with(')'))
        || trimmed.parse::<i128>().is_ok()
        || trimmed.parse::<f64>().is_ok()
        || upper == "NULL"
        || upper == "CURRENT_TIMESTAMP"
        || upper == "NOW()"
    {
        trimmed.to_string()
    } else {
        format!("'{}'", trimmed.replace('\'', "''"))
    }
}
/// A set of migrations scheduled for execution, together with the direction
/// in which they should be run.
pub struct MigrationPlan {
    /// Migrations in execution order.
    pub migrations: Vec<Migration>,
    /// Whether the plan applies (`Up`) or reverts (`Down`) the migrations.
    pub direction: MigrationDirection,
}
/// Direction in which a migration plan is executed.
#[derive(Debug, Clone)]
pub enum MigrationDirection {
    /// Apply schema changes.
    Up,
    /// Revert schema changes.
    Down,
}
/// High-level migration operations supported by the tool.
#[derive(Debug, Clone)]
pub enum MigrationCommand {
    /// Create a new migration file.
    Create {
        /// Human-readable description of the migration.
        description: String,
        /// Directory in which the migration is created.
        directory: String,
    },
    /// Apply pending migrations.
    Up {
        /// Target version to migrate up to; `None` semantics are decided by
        /// the command executor (presumably "apply all" — not visible here).
        target_version: Option<u32>,
    },
    /// Revert applied migrations.
    Down {
        /// Target version to migrate down to; `None` semantics are decided
        /// by the command executor (not visible here).
        target_version: Option<u32>,
    },
    /// Report the current migration status.
    Status,
    /// Generate a migration by diffing two schemas.
    Generate {
        /// Source (old) schema.
        from_schema: String,
        /// Target (new) schema.
        to_schema: String,
        /// File the generated SQL is written to.
        output_file: String,
    },
}
/// Compares two schema snapshots and expresses the difference as a list of
/// [`Migration`]s (see [`SchemaDiffer::diff`]).
pub struct SchemaDiffer {
    /// Schema as it currently exists.
    old_schema: Schema,
    /// Schema as it should look after migrating.
    new_schema: Schema,
}
impl SchemaDiffer {
    /// Builds a differ comparing `old_schema` (current state) against
    /// `new_schema` (desired state).
    pub fn new(old_schema: Schema, new_schema: Schema) -> Self {
        Self { old_schema, new_schema }
    }

    /// Computes the changes needed to turn the old schema into the new one.
    ///
    /// Returns an empty vector when the schemas already match; otherwise a
    /// single migration (version 1, "Schema changes") whose table changes
    /// are ordered: created tables, then dropped tables, then altered
    /// tables.
    pub fn diff(&self) -> Vec<Migration> {
        let mut migration = Migration::new(1, "Schema changes".to_string());
        // Tables present only in the new schema must be created.
        self.new_schema
            .tables
            .iter()
            .filter(|table| !self.old_schema.has_table(&table.name))
            .for_each(|table| migration.add_table_change(TableChange::CreateTable(table.clone())));
        // Tables present only in the old schema must be dropped.
        self.old_schema
            .tables
            .iter()
            .filter(|table| !self.new_schema.has_table(&table.name))
            .for_each(|table| {
                migration.add_table_change(TableChange::DropTable {
                    table_name: table.name.clone(),
                })
            });
        // Tables present in both schemas may need ALTER statements.
        for table in &self.new_schema.tables {
            let Some(previous) = self.old_schema.get_table(&table.name) else {
                continue;
            };
            let column_changes = self.detect_column_changes(previous, table);
            let added_columns = self.detect_added_columns(previous, table);
            let removed_columns = self.detect_removed_columns(previous, table);
            let added_indexes = self.detect_added_indexes(previous, table);
            let removed_indexes = self.detect_removed_indexes(previous, table);
            let added_foreign_keys = self.detect_added_foreign_keys(previous, table);
            let removed_foreign_keys = self.detect_removed_foreign_keys(previous, table);
            let unchanged = column_changes.is_empty()
                && added_columns.is_empty()
                && removed_columns.is_empty()
                && added_indexes.is_empty()
                && removed_indexes.is_empty()
                && added_foreign_keys.is_empty()
                && removed_foreign_keys.is_empty();
            if !unchanged {
                migration.add_table_change(TableChange::AlterTable {
                    table_name: table.name.clone(),
                    column_changes,
                    added_columns,
                    removed_columns,
                    added_indexes,
                    removed_indexes,
                    added_foreign_keys,
                    removed_foreign_keys,
                });
            }
        }
        if migration.table_changes.is_empty() {
            Vec::new()
        } else {
            vec![migration]
        }
    }

    /// Reports type, nullability and default-value changes for columns that
    /// exist in both versions of the table.
    fn detect_column_changes(&self, old_table: &Table, new_table: &Table) -> Vec<ColumnAlteration> {
        let mut alterations = Vec::new();
        for column in &new_table.columns {
            let Some(old_column) = old_table.columns.iter().find(|c| c.name == column.name) else {
                continue;
            };
            if old_column.column_type != column.column_type {
                alterations.push(ColumnAlteration::TypeChanged {
                    column_name: column.name.clone(),
                    old_type: old_column.column_type.clone(),
                    new_type: column.column_type.clone(),
                });
            }
            if old_column.is_nullable != column.is_nullable {
                alterations.push(ColumnAlteration::NullabilityChanged {
                    column_name: column.name.clone(),
                    old_nullable: old_column.is_nullable,
                    new_nullable: column.is_nullable,
                });
            }
            if old_column.default_value != column.default_value {
                alterations.push(ColumnAlteration::DefaultChanged {
                    column_name: column.name.clone(),
                    old_default: old_column.default_value.clone(),
                    new_default: column.default_value.clone(),
                });
            }
        }
        alterations
    }

    /// Columns that exist in the new table but not in the old one.
    fn detect_added_columns(&self, old_table: &Table, new_table: &Table) -> Vec<Column> {
        new_table
            .columns
            .iter()
            .filter(|col| old_table.columns.iter().all(|old| old.name != col.name))
            .cloned()
            .collect()
    }

    /// Names of columns that exist in the old table but not in the new one.
    fn detect_removed_columns(&self, old_table: &Table, new_table: &Table) -> Vec<String> {
        old_table
            .columns
            .iter()
            .filter(|col| new_table.columns.iter().all(|new| new.name != col.name))
            .map(|col| col.name.clone())
            .collect()
    }

    /// Indexes that exist in the new table but not in the old one.
    fn detect_added_indexes(&self, old_table: &Table, new_table: &Table) -> Vec<Index> {
        new_table
            .indexes
            .iter()
            .filter(|idx| old_table.indexes.iter().all(|old| old.name != idx.name))
            .cloned()
            .collect()
    }

    /// Names of indexes that exist in the old table but not in the new one.
    fn detect_removed_indexes(&self, old_table: &Table, new_table: &Table) -> Vec<String> {
        old_table
            .indexes
            .iter()
            .filter(|idx| new_table.indexes.iter().all(|new| new.name != idx.name))
            .map(|idx| idx.name.clone())
            .collect()
    }

    /// Foreign keys that exist in the new table but not in the old one.
    fn detect_added_foreign_keys(&self, old_table: &Table, new_table: &Table) -> Vec<ForeignKey> {
        new_table
            .foreign_keys
            .iter()
            .filter(|fk| old_table.foreign_keys.iter().all(|old| old.name != fk.name))
            .cloned()
            .collect()
    }

    /// Names of foreign keys that exist in the old table but not in the new
    /// one.
    fn detect_removed_foreign_keys(&self, old_table: &Table, new_table: &Table) -> Vec<String> {
        old_table
            .foreign_keys
            .iter()
            .filter(|fk| new_table.foreign_keys.iter().all(|new| new.name != fk.name))
            .map(|fk| fk.name.clone())
            .collect()
    }
}
/// Generates dialect-specific DDL statements (CREATE / ALTER / DROP) for a
/// target database.
#[derive(Debug, Clone)]
pub struct SqlGenerator {
    /// Target database dialect the SQL is generated for.
    pub db_type: DatabaseType,
}
impl SqlGenerator {
pub fn new(db_type: DatabaseType) -> Self {
Self { db_type }
}
pub fn generate_column_def(&self, column_type: &ColumnType) -> String {
column_type.to_sql(self.db_type)
}
pub fn generate_create_table_sql(&self, table: &Table) -> String {
let table_name = match validate_sql_identifier(&table.name, "表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid table name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let mut sql = format!("CREATE TABLE {} (\n", table_name);
let column_defs: Vec<String> = table
.columns
.iter()
.map(|col| self.generate_column_definition(col, &table.primary_key_columns))
.collect();
sql.push_str(&column_defs.join(",\n"));
if !table.primary_key_columns.is_empty() {
sql.push_str(",\n");
let pk_columns: Vec<String> = table
.primary_key_columns
.iter()
.map(|col| match validate_sql_identifier(col, "主键列名") {
Ok(validated) => validated,
Err(e) => {
tracing::error!("Invalid primary key column: {}", e);
"***INVALID***".to_string()
}
})
.collect();
sql.push_str(&format!(" PRIMARY KEY ({})", pk_columns.join(", ")));
}
sql.push_str("\n);");
for index in &table.indexes {
if !index.is_constraint {
sql.push_str("\n\n");
sql.push_str(&self.generate_create_index_sql(index));
}
}
for fk in &table.foreign_keys {
sql.push_str("\n\n");
sql.push_str(&self.generate_add_foreign_key_sql(fk));
}
sql
}
fn generate_column_definition(&self, column: &Column, _pk_columns: &[String]) -> String {
let column_name = match validate_sql_identifier(&column.name, "列名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid column name: {}", e);
return format!(" -- 错误: {}\n", e);
}
};
let mut def = format!(" {} {}", column_name, column.column_type.to_sql(self.db_type));
if column.is_auto_increment && column.is_primary_key {
match self.db_type {
DatabaseType::MySql => def.push_str(" AUTO_INCREMENT"),
DatabaseType::Sqlite => def.push_str(" PRIMARY KEY AUTOINCREMENT"),
_ => {}
}
}
if !column.is_nullable {
def.push_str(" NOT NULL");
}
if let Some(default) = &column.default_value {
let sanitized_default = sanitize_default_value(default);
def.push_str(&format!(" DEFAULT {}", sanitized_default));
}
if column.is_primary_key && !column.is_auto_increment {
}
def
}
pub fn generate_create_index_sql(&self, index: &Index) -> String {
let index_name = match validate_sql_identifier(&index.name, "索引名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid index name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let table_name = match validate_sql_identifier(&index.table_name, "表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid table name in index: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let validated_columns: Vec<String> = index
.columns
.iter()
.map(|col| match validate_sql_identifier(col, "索引列名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid index column: {}", e);
"***INVALID***".to_string()
}
})
.collect();
let unique = if index.is_unique { "UNIQUE " } else { "" };
format!(
"CREATE {}INDEX {} ON {} ({})",
unique,
index_name,
table_name,
validated_columns.join(", ")
)
}
fn generate_add_foreign_key_sql(&self, fk: &ForeignKey) -> String {
let table_name = match validate_sql_identifier(&fk.table_name, "外键表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid foreign key table name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let constraint_name = match validate_sql_identifier(&fk.name, "外键约束名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid foreign key constraint name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let column_name = match validate_sql_identifier(&fk.column_name, "外键列名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid foreign key column: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let referenced_table_name = match validate_sql_identifier(&fk.referenced_table_name, "外键引用表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid referenced table name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let referenced_column_name = match validate_sql_identifier(&fk.referenced_column_name, "外键引用列名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid referenced column: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let mut sql = format!(
"ALTER TABLE {} ADD CONSTRAINT {} FOREIGN KEY ({}) REFERENCES {}({})",
table_name, constraint_name, column_name, referenced_table_name, referenced_column_name
);
if let Some(on_delete) = &fk.on_delete {
sql.push_str(&format!(" ON DELETE {}", on_delete));
}
if let Some(on_update) = &fk.on_update {
sql.push_str(&format!(" ON UPDATE {}", on_update));
}
sql.push(';');
sql
}
pub fn generate_drop_table_sql(&self, table_name: &str) -> String {
let validated_name = match validate_sql_identifier(table_name, "表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid table name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
format!("DROP TABLE {};", validated_name)
}
pub fn generate_add_column_sql(&self, table_name: &str, column: &Column) -> String {
let validated_table_name = match validate_sql_identifier(table_name, "表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid table name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let col_def = self.generate_column_definition(column, &Vec::new());
format!(
"ALTER TABLE {} ADD {};",
validated_table_name,
col_def.trim_start_matches(" ")
)
}
pub fn generate_drop_column_sql(&self, table_name: &str, column_name: &str) -> String {
let validated_table_name = match validate_sql_identifier(table_name, "表名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid table name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
let validated_column_name = match validate_sql_identifier(column_name, "列名") {
Ok(name) => name,
Err(e) => {
tracing::error!("Invalid column name: {}", e);
return format!("-- 错误: {}\n", e);
}
};
match self.db_type {
DatabaseType::MySql => {
format!(
"ALTER TABLE {} DROP COLUMN {};",
validated_table_name, validated_column_name
)
}
DatabaseType::Postgres => {
format!(
"ALTER TABLE {} DROP COLUMN {};",
validated_table_name, validated_column_name
)
}
DatabaseType::Sqlite => {
format!(
"-- SQLite 不支持直接删除列,请手动重建表 {}
ALTER TABLE {} DROP COLUMN {};",
validated_table_name, validated_table_name, validated_column_name
)
}
}
}
pub fn generate_migration_sql(&self, migration: &Migration) -> String {
let mut sql = String::new();
for change in &migration.table_changes {
match change {
TableChange::CreateTable(table) => {
sql.push_str(&format!("-- 创建表: {}\n", table.name));
sql.push_str(&self.generate_create_table_sql(table));
sql.push_str("\n\n");
}
TableChange::DropTable { table_name } => {
sql.push_str(&format!("-- 删除表: {}\n", table_name));
sql.push_str(&self.generate_drop_table_sql(table_name));
sql.push_str("\n\n");
}
TableChange::AlterTable {
table_name,
added_columns,
removed_columns,
added_indexes,
removed_indexes,
added_foreign_keys,
removed_foreign_keys,
..
} => {
sql.push_str(&format!("-- 修改表: {}\n", table_name));
for col in added_columns {
sql.push_str(&format!("-- 添加列: {}\n", col.name));
sql.push_str(&self.generate_add_column_sql(table_name, col));
sql.push('\n');
}
for col_name in removed_columns {
sql.push_str(&format!("-- 删除列: {}\n", col_name));
sql.push_str(&self.generate_drop_column_sql(table_name, col_name));
sql.push('\n');
}
for index in added_indexes {
sql.push_str(&format!("-- 添加索引: {}\n", index.name));
sql.push_str(&self.generate_create_index_sql(index));
sql.push('\n');
}
for index_name in removed_indexes {
sql.push_str(&format!("-- 删除索引: {}\n", index_name));
sql.push_str(&format!("DROP INDEX {};\n", index_name));
}
for fk in added_foreign_keys {
sql.push_str(&format!("-- 添加外键: {}\n", fk.name));
sql.push_str(&self.generate_add_foreign_key_sql(fk));
sql.push('\n');
}
for fk_name in removed_foreign_keys {
sql.push_str(&format!("-- 删除外键: {}\n", fk_name));
sql.push_str(&format!("ALTER TABLE {} DROP CONSTRAINT {};\n", table_name, fk_name));
}
sql.push('\n');
}
}
}
sql.trim_end().to_string()
}
}
/// Minimal line-oriented parser that extracts a [`Table`] definition from
/// SeaORM-style Rust entity source code (struct fields plus `#[sea_orm(...)]`
/// attributes).
#[derive(Debug, Clone)]
pub(crate) struct RustEntityParser;
impl RustEntityParser {
pub fn parse_entity(entity_code: &str, table_name: &str) -> Result<Table, String> {
let columns = Self::extract_columns_from_code(entity_code)?;
let primary_key_columns = columns
.iter()
.filter(|c| c.is_primary_key)
.map(|c| c.name.clone())
.collect();
Ok(Table {
name: table_name.to_string(),
columns,
primary_key_columns,
indexes: Vec::new(),
foreign_keys: Vec::new(),
comment: None,
})
}
fn extract_columns_from_code(entity_code: &str) -> Result<Vec<Column>, String> {
let mut columns = Vec::new();
let lines: Vec<&str> = entity_code.lines().collect();
let mut current_field_name: Option<String> = None;
let mut current_field_type: Option<String> = None;
let mut current_column_type: Option<ColumnType> = None;
let mut field_column_type: Option<ColumnType> = None; let mut is_primary_key = false;
let mut is_nullable = true;
let mut is_auto_increment = false;
for line in &lines {
let line = line.trim();
if let Some((field_name, field_type)) = Self::extract_field_and_type(line) {
if let Some(ref prev_field_name) = current_field_name {
let col_type = field_column_type
.take()
.or_else(|| Self::infer_column_type(¤t_field_type));
if let Some(type_result) = col_type {
if !columns.iter().any(|c: &Column| c.name == *prev_field_name) {
columns.push(Column {
name: prev_field_name.clone(),
column_type: type_result,
is_primary_key,
is_nullable,
has_default: false,
default_value: None,
is_auto_increment,
comment: None,
});
}
}
is_primary_key = false;
is_nullable = true;
is_auto_increment = false;
}
field_column_type = current_column_type.take();
current_field_name = Some(field_name);
current_field_type = Some(field_type);
continue;
}
if line.contains("column_type") {
current_column_type = Self::extract_column_type(line);
}
if line.contains("primary_key") {
is_primary_key = true;
}
if line.contains("NotNull") || line.contains("not_null") {
is_nullable = false;
}
if line.contains("AutoIncrement") || line.contains("auto_increment") {
is_auto_increment = true;
}
if line.starts_with("#[") {
continue;
}
}
if let Some(ref field_name) = current_field_name {
let col_type = field_column_type
.take()
.or_else(|| Self::infer_column_type(¤t_field_type));
if let Some(type_result) = col_type {
columns.push(Column {
name: field_name.clone(),
column_type: type_result,
is_primary_key,
is_nullable,
has_default: false,
default_value: None,
is_auto_increment,
comment: None,
});
}
}
if columns.is_empty() {
return Err("未能解析到任何列".to_string());
}
Ok(columns)
}
fn extract_field_and_type(line: &str) -> Option<(String, String)> {
let trimmed = line.trim();
if trimmed.starts_with("#[") {
return None;
}
if trimmed.starts_with("pub struct ") || trimmed.starts_with("struct ") {
return None;
}
let colon_idx = trimmed.find(':')?;
let before_colon = &trimmed[..colon_idx];
let after_colon = &trimmed[colon_idx + 1..];
let mut field_name = before_colon.trim_end().trim_end_matches(',').trim().to_string();
if field_name.starts_with("pub ") {
field_name = field_name[4..].to_string();
}
if field_name.starts_with("#[") || field_name.starts_with("fn ") || field_name.is_empty() {
return None;
}
let mut type_str = after_colon.trim();
let type_end = type_str
.find(',')
.or_else(|| type_str.find('}'))
.unwrap_or(type_str.len());
type_str = &type_str[..type_end];
if field_name.is_empty() || type_str.is_empty() {
return None;
}
Some((field_name.to_string(), type_str.to_string()))
}
fn extract_column_type(attr_line: &str) -> Option<ColumnType> {
if let Some(start) = attr_line.find("column_type") {
let after = &attr_line[start..];
if let Some(eq_idx) = after.find('=') {
let type_str = &after[eq_idx + 1..];
if let Some(quote_start) = type_str.find('"') {
if let Some(quote_end) = type_str[quote_start + 1..].find('"') {
let type_content = &type_str[quote_start + 1..quote_start + 1 + quote_end];
return Some(Self::parse_column_type_str(type_content));
}
}
}
}
None
}
fn parse_column_type_str(type_str: &str) -> ColumnType {
match type_str {
"Integer" | "Int" | "i32" => ColumnType::Integer,
"BigInteger" | "BigInt" => ColumnType::BigInteger,
"String" => ColumnType::String(Some(255)),
s if s.starts_with("String(") => {
if let Some(len_str) = s.strip_prefix("String(").and_then(|s| s.strip_suffix(')')) {
if let Ok(len) = len_str.parse() {
return ColumnType::String(Some(len));
}
}
ColumnType::String(Some(255))
}
"Text" => ColumnType::Text,
"Boolean" | "Bool" | "bool" => ColumnType::Boolean,
"Float" | "f32" => ColumnType::Float,
"Double" | "f64" => ColumnType::Double,
"Date" => ColumnType::Date,
"Time" => ColumnType::Time,
"DateTime" | "DateTimeUtc" => ColumnType::DateTime,
"Timestamp" | "TimestampUtc" => ColumnType::Timestamp,
"Json" | "JsonValue" => ColumnType::Json,
"Binary" | "Vec<u8>" => ColumnType::Binary,
_ => ColumnType::Custom(type_str.to_string()),
}
}
fn infer_column_type(field_type: &Option<String>) -> Option<ColumnType> {
let type_str = field_type.as_ref()?.to_lowercase();
let inner_type = if type_str.starts_with("option<") {
if let Some(end) = type_str.find('>') {
&type_str[7..end]
} else {
&type_str
}
} else {
&type_str
};
match inner_type {
t if t.contains("i32") || t == "integer" || t == "int" => Some(ColumnType::Integer),
t if t.contains("i64") || t == "biginteger" || t == "bigint" => Some(ColumnType::BigInteger),
t if t.contains("string") || t.contains("&str") => {
if let Some(len_start) = t.find('<') {
if let Some(len_end) = t[len_start..].find('>') {
let len_str = &t[len_start + 1..len_start + len_end];
if let Ok(len) = len_str.parse() {
return Some(ColumnType::String(Some(len)));
}
}
}
Some(ColumnType::String(Some(255)))
}
t if t.contains("text") || t.contains("string") => Some(ColumnType::Text),
t if t.contains("bool") => Some(ColumnType::Boolean),
t if t.contains("f32") | t.contains("float") => Some(ColumnType::Float),
t if t.contains("f64") | t.contains("double") => Some(ColumnType::Double),
t if t.contains("date") && t.contains("time") => Some(ColumnType::DateTime),
t if t.contains("date") => Some(ColumnType::Date),
t if t.contains("time") => Some(ColumnType::Time),
t if t.contains("timestamp") => Some(ColumnType::Timestamp),
t if t.contains("json") => Some(ColumnType::Json),
t if t.contains("vec<u8>") || t.contains("binary") => Some(ColumnType::Binary),
_ => None,
}
}
pub fn generate_migration_sql(
entity_code: &str,
table_name: &str,
db_type: DatabaseType,
) -> Result<String, String> {
let table = Self::parse_entity(entity_code, table_name)?;
let generator = SqlGenerator::new(db_type);
Ok(generator.generate_create_table_sql(&table))
}
}
#[cfg(test)]
mod tests {
use super::*;
// Column types must render the correct SQL per dialect; note the SQLite
// generator maps Boolean to INTEGER.
#[test]
fn test_column_type_to_sql() {
let pg = SqlGenerator::new(DatabaseType::Postgres);
let mysql = SqlGenerator::new(DatabaseType::MySql);
let sqlite = SqlGenerator::new(DatabaseType::Sqlite);
assert_eq!(pg.generate_column_def(&ColumnType::Integer), "INTEGER");
assert_eq!(mysql.generate_column_def(&ColumnType::Integer), "INTEGER");
assert_eq!(sqlite.generate_column_def(&ColumnType::Integer), "INTEGER");
assert_eq!(pg.generate_column_def(&ColumnType::Boolean), "BOOLEAN");
assert_eq!(mysql.generate_column_def(&ColumnType::Boolean), "BOOLEAN");
assert_eq!(sqlite.generate_column_def(&ColumnType::Boolean), "INTEGER");
}
// A table present only in the new schema must yield exactly one migration
// containing a single CreateTable change.
#[test]
fn test_schema_diff_new_table() {
let old_schema = Schema::new(DatabaseType::Postgres);
let mut new_schema = Schema::new(DatabaseType::Postgres);
let users_table = Table {
name: "users".to_string(),
columns: vec![Column {
name: "id".to_string(),
column_type: ColumnType::Integer,
is_primary_key: true,
is_nullable: false,
has_default: false,
default_value: None,
is_auto_increment: true,
comment: None,
}],
primary_key_columns: vec!["id".to_string()],
indexes: vec![],
foreign_keys: vec![],
comment: None,
};
new_schema.add_table(users_table);
let differ = SchemaDiffer::new(old_schema, new_schema);
let migrations = differ.diff();
assert_eq!(migrations.len(), 1);
assert_eq!(migrations[0].table_changes.len(), 1);
if let TableChange::CreateTable(table) = &migrations[0].table_changes[0] {
assert_eq!(table.name, "users");
} else {
unreachable!("Expected CreateTable change");
}
}
// A table present only in the old schema must yield a single DropTable
// change.
#[test]
fn test_schema_diff_drop_table() {
let mut old_schema = Schema::new(DatabaseType::Postgres);
let new_schema = Schema::new(DatabaseType::Postgres);
let users_table = Table {
name: "users".to_string(),
columns: vec![],
primary_key_columns: vec![],
indexes: vec![],
foreign_keys: vec![],
comment: None,
};
old_schema.add_table(users_table);
let differ = SchemaDiffer::new(old_schema, new_schema);
let migrations = differ.diff();
assert_eq!(migrations.len(), 1);
assert_eq!(migrations[0].table_changes.len(), 1);
if let TableChange::DropTable { table_name } = &migrations[0].table_changes[0] {
assert_eq!(table_name, "users");
} else {
unreachable!("Expected DropTable change");
}
}
// CREATE TABLE output must contain the column definitions, NOT NULL
// markers and the table-level PRIMARY KEY constraint.
#[test]
fn test_sql_generation() {
let pg = SqlGenerator::new(DatabaseType::Postgres);
let table = Table {
name: "users".to_string(),
columns: vec![
Column {
name: "id".to_string(),
column_type: ColumnType::Integer,
is_primary_key: true,
is_nullable: false,
has_default: false,
default_value: None,
is_auto_increment: true,
comment: None,
},
Column {
name: "name".to_string(),
column_type: ColumnType::String(Some(255)),
is_primary_key: false,
is_nullable: false,
has_default: false,
default_value: None,
is_auto_increment: false,
comment: None,
},
],
primary_key_columns: vec!["id".to_string()],
indexes: vec![],
foreign_keys: vec![],
comment: None,
};
let sql = pg.generate_create_table_sql(&table);
assert!(sql.contains("CREATE TABLE users"));
assert!(sql.contains("id INTEGER"));
assert!(sql.contains("name VARCHAR(255)"));
assert!(sql.contains("NOT NULL"));
assert!(sql.contains("PRIMARY KEY (id)"));
}
// The entity parser must pick up field names, the primary_key attribute
// (which precedes the field it annotates) and explicit column_type
// attributes.
#[test]
fn test_rust_entity_parser_basic() {
let entity_code = r#"
#[sea_orm(table_name = "users")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i32,
#[sea_orm(column_type = "String(255)")]
pub name: String,
#[sea_orm(column_type = "Text")]
pub bio: Option<String>,
}
"#;
let table = RustEntityParser::parse_entity(entity_code, "users").expect("Failed to parse entity code");
assert_eq!(table.name, "users");
assert_eq!(table.columns.len(), 3);
assert_eq!(table.primary_key_columns, vec!["id"]);
let id_col = table
.columns
.iter()
.find(|c| c.name == "id")
.expect("id column should exist");
assert_eq!(id_col.column_type, ColumnType::Integer);
assert!(id_col.is_primary_key);
let name_col = table
.columns
.iter()
.find(|c| c.name == "name")
.expect("name column should exist");
assert_eq!(name_col.column_type, ColumnType::String(Some(255)));
}
// End-to-end: entity source to CREATE TABLE DDL, including the i64 ->
// BIGINT inference and explicit column_type attributes.
#[test]
fn test_rust_entity_generate_migration() {
let entity_code = r#"
#[sea_orm(table_name = "posts")]
pub struct Model {
#[sea_orm(primary_key)]
pub id: i64,
#[sea_orm(column_type = "String(255)")]
pub title: String,
#[sea_orm(column_type = "Text")]
pub content: String,
#[sea_orm(column_type = "DateTime")]
pub created_at: DateTimeUtc,
}
"#;
let sql = RustEntityParser::generate_migration_sql(entity_code, "posts", DatabaseType::Postgres)
.expect("Failed to generate migration SQL");
assert!(sql.contains("CREATE TABLE posts"));
assert!(sql.contains("id BIGINT"));
assert!(sql.contains("title VARCHAR(255)"));
assert!(sql.contains("content TEXT"));
assert!(sql.contains("created_at TIMESTAMP"));
}
// Spot checks for the explicit attribute-type parser, including the
// parameterized String(len) form.
#[test]
fn test_parse_column_type_string() {
assert_eq!(RustEntityParser::parse_column_type_str("Integer"), ColumnType::Integer);
assert_eq!(
RustEntityParser::parse_column_type_str("String(100)"),
ColumnType::String(Some(100))
);
assert_eq!(RustEntityParser::parse_column_type_str("Text"), ColumnType::Text);
assert_eq!(RustEntityParser::parse_column_type_str("Boolean"), ColumnType::Boolean);
assert_eq!(
RustEntityParser::parse_column_type_str("DateTime"),
ColumnType::DateTime
);
assert_eq!(RustEntityParser::parse_column_type_str("Json"), ColumnType::Json);
}
}