use std::sync::OnceLock;
use parking_lot::RwLock;
use crate::database::Database;
use crate::error::{Error, Result};
use crate::{tide_info, tide_warn, tide_debug};
use sea_orm::{
ConnectionTrait, DbBackend, EntityTrait, Statement,
schema::{Schema, SchemaBuilder},
sea_query::{
Table, ColumnDef as SeaColumnDef, Alias, Expr,
ColumnType as SeaColumnType, PostgresQueryBuilder,
MysqlQueryBuilder, SqliteQueryBuilder,
},
};
/// Boxed callback that receives a [`SchemaBuilder`] and returns the builder to keep using.
///
/// The `Send + Sync` bounds let these callbacks be stored in the `ENTITY_REGISTRY` static.
pub type EntityRegistrationFn = Box<dyn Fn(SchemaBuilder) -> SchemaBuilder + Send + Sync>;
/// Global list of entity registration callbacks; initialised lazily by `get_entity_registry`.
static ENTITY_REGISTRY: OnceLock<RwLock<Vec<EntityRegistrationFn>>> = OnceLock::new();
/// Global list of directly registered [`ModelSchema`] values; initialised lazily by
/// `get_direct_schemas`.
static DIRECT_SCHEMAS: OnceLock<RwLock<Vec<ModelSchema>>> = OnceLock::new();
/// Returns the global entity-callback registry, creating it (empty) on first access.
fn get_entity_registry() -> &'static RwLock<Vec<EntityRegistrationFn>> {
    let empty_registry = || RwLock::new(Vec::<EntityRegistrationFn>::new());
    ENTITY_REGISTRY.get_or_init(empty_registry)
}
/// Returns the global list of directly registered schemas, creating it (empty) on first access.
fn get_direct_schemas() -> &'static RwLock<Vec<ModelSchema>> {
    let empty_list = || RwLock::new(Vec::<ModelSchema>::new());
    DIRECT_SCHEMAS.get_or_init(empty_list)
}
/// A model that can describe its own table and register that description for sync.
pub trait SyncModel {
    /// Builds the [`ModelSchema`] describing this model's table.
    fn sync_schema() -> ModelSchema;

    /// Registers this model with [`SyncRegistry`].
    ///
    /// The default implementation builds the schema with [`SyncModel::sync_schema`] and
    /// passes it to `SyncRegistry::register`.
    fn register_for_sync() {
        let schema = Self::sync_schema();
        SyncRegistry::register(schema);
    }
}
/// A group of models that can all be registered for sync with a single call.
///
/// This file implements it for `()` and, through `impl_register_models_tuples!`, for tuples
/// whose members all implement [`SyncModel`].
pub trait RegisterModels {
/// Registers every model in the group.
fn register_all();
}
/// The empty tuple contains no models, so registering it does nothing.
impl RegisterModels for () {
fn register_all() {}
}
/// Implements [`RegisterModels`] for tuples whose members all implement [`SyncModel`].
///
/// Given the identifiers `A, B, C` it emits impls for `(A, B, C)`, `(B, C)` and `(C,)`:
/// each expansion covers the full identifier list and then recurses on the list without
/// its first identifier, until a single identifier is left.
macro_rules! impl_register_models_tuples {
    // Two or more identifiers: implement for the whole tuple, then recurse on the tail.
    ($head:ident, $($tail:ident),+) => {
        impl<$head: SyncModel, $($tail: SyncModel),+> RegisterModels for ($head, $($tail),+) {
            fn register_all() {
                <$head as SyncModel>::register_for_sync();
                $(<$tail as SyncModel>::register_for_sync();)+
            }
        }
        impl_register_models_tuples!($($tail),+);
    };
    // One identifier: the 1-tuple, which ends the recursion.
    ($only:ident) => {
        impl<$only: SyncModel> RegisterModels for ($only,) {
            fn register_all() {
                <$only as SyncModel>::register_for_sync();
            }
        }
    };
}
// Generates `RegisterModels` impls for tuples of 1 through 200 models: the macro emits an
// impl for the full identifier list and recurses with the first identifier removed.
// NOTE(review): that is 200 levels of macro recursion, above rustc's default
// `recursion_limit` of 128 — presumably the crate root raises the limit; confirm.
impl_register_models_tuples!(
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18,
T19, T20, T21, T22, T23, T24, T25, T26, T27, T28, T29, T30, T31, T32, T33, T34, T35,
T36, T37, T38, T39, T40, T41, T42, T43, T44, T45, T46, T47, T48, T49, T50, T51, T52,
T53, T54, T55, T56, T57, T58, T59, T60, T61, T62, T63, T64, T65, T66, T67, T68, T69,
T70, T71, T72, T73, T74, T75, T76, T77, T78, T79, T80, T81, T82, T83, T84, T85, T86,
T87, T88, T89, T90, T91, T92, T93, T94, T95, T96, T97, T98, T99, T100, T101, T102,
T103, T104, T105, T106, T107, T108, T109, T110, T111, T112, T113, T114, T115, T116,
T117, T118, T119, T120, T121, T122, T123, T124, T125, T126, T127, T128, T129, T130,
T131, T132, T133, T134, T135, T136, T137, T138, T139, T140, T141, T142, T143, T144,
T145, T146, T147, T148, T149, T150, T151, T152, T153, T154, T155, T156, T157, T158,
T159, T160, T161, T162, T163, T164, T165, T166, T167, T168, T169, T170, T171, T172,
T173, T174, T175, T176, T177, T178, T179, T180, T181, T182, T183, T184, T185, T186,
T187, T188, T189, T190, T191, T192, T193, T194, T195, T196, T197, T198, T199, T200
);
/// Entry point to the two global registries used by database sync: entity registration
/// callbacks and directly registered [`ModelSchema`] values.
pub struct SyncRegistry;

impl SyncRegistry {
    /// Appends a callback that registers `E::default()` on a [`SchemaBuilder`].
    ///
    /// No de-duplication is performed: each call adds one more callback.
    pub fn register_entity<E: EntityTrait + Default + 'static>() {
        let callback: EntityRegistrationFn =
            Box::new(|schema_builder: SchemaBuilder| schema_builder.register(E::default()));
        get_entity_registry().write().push(callback);
    }

    /// Creates a fresh builder for `backend` and runs every registered callback on it,
    /// in registration order.
    pub fn build_schema_builder(backend: DbBackend) -> SchemaBuilder {
        let callbacks = get_entity_registry().read();
        callbacks
            .iter()
            .fold(Schema::new(backend).builder(), |builder, add_entity| {
                add_entity(builder)
            })
    }

    /// Number of entity registration callbacks currently stored.
    pub fn entity_count() -> usize {
        get_entity_registry().read().len()
    }

    /// Number of directly registered schemas currently stored.
    pub fn legacy_count() -> usize {
        get_direct_schemas().read().len()
    }

    /// Empties both registries.
    pub fn clear() {
        // Both write guards are held until the function returns.
        let mut callbacks = get_entity_registry().write();
        callbacks.clear();
        let mut known_schemas = get_direct_schemas().write();
        known_schemas.clear();
    }

    /// Stores `schema` unless one with the same `table_name` is already registered.
    ///
    /// Only `table_name` is compared; `schema_name` plays no part in the duplicate check.
    pub fn register(schema: ModelSchema) {
        let mut known_schemas = get_direct_schemas().write();
        let already_registered = known_schemas
            .iter()
            .any(|existing| existing.table_name == schema.table_name);
        if already_registered {
            return;
        }
        known_schemas.push(schema);
    }

    /// Returns a copy of every directly registered schema, in registration order.
    pub fn get_all() -> Vec<ModelSchema> {
        get_direct_schemas().read().to_vec()
    }
}
/// Description of one column of a directly registered table.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Rust type name as text (for example `"i32"` or `"Option<String>"`).
    pub col_type: String,
    /// Whether the column accepts NULL.
    pub nullable: bool,
    /// Whether the column is the primary key.
    pub primary_key: bool,
    /// Whether the column auto-increments.
    pub auto_increment: bool,
    /// Optional SQL default expression.
    pub default: Option<String>,
}

impl ColumnDef {
    /// Creates a nullable column with no key, no auto-increment and no default.
    pub fn new(name: impl Into<String>, col_type: impl Into<String>) -> Self {
        let (name, col_type) = (name.into(), col_type.into());
        Self {
            name,
            col_type,
            nullable: true,
            primary_key: false,
            auto_increment: false,
            default: None,
        }
    }

    /// Marks the column as primary key, which also makes it non-nullable.
    pub fn primary_key(self) -> Self {
        Self {
            primary_key: true,
            nullable: false,
            ..self
        }
    }

    /// Marks the column as auto-incrementing.
    pub fn auto_increment(self) -> Self {
        Self {
            auto_increment: true,
            ..self
        }
    }

    /// Marks the column as non-nullable.
    pub fn not_null(self) -> Self {
        Self {
            nullable: false,
            ..self
        }
    }

    /// Sets the default expression, replacing any previous one.
    pub fn default(self, expr: impl Into<String>) -> Self {
        Self {
            default: Some(expr.into()),
            ..self
        }
    }
}
/// Description of a table registered directly (without a SeaORM entity).
#[derive(Debug, Clone)]
pub struct ModelSchema {
    /// Table name.
    pub table_name: String,
    /// Database schema (namespace) name; `"public"` unless overridden.
    pub schema_name: String,
    /// Column definitions in declaration order.
    pub columns: Vec<ColumnDef>,
}

impl ModelSchema {
    /// Creates a schema for `table_name` in the `"public"` schema with no columns.
    pub fn new(table_name: impl Into<String>) -> Self {
        let table_name = table_name.into();
        let schema_name = String::from("public");
        Self {
            table_name,
            schema_name,
            columns: Vec::new(),
        }
    }

    /// Replaces the database schema name.
    pub fn schema(self, schema: impl Into<String>) -> Self {
        Self {
            schema_name: schema.into(),
            ..self
        }
    }

    /// Appends one column.
    pub fn column(mut self, col: ColumnDef) -> Self {
        self.columns.extend(std::iter::once(col));
        self
    }

    /// Appends several columns, keeping their order.
    pub fn columns(mut self, cols: Vec<ColumnDef>) -> Self {
        let mut extra = cols;
        self.columns.append(&mut extra);
        self
    }
}
/// Syncs all registered models without forcing: equivalent to calling
/// [`sync_database_with_options`] with `force_sync` set to `false`.
pub async fn sync_database(db: &Database) -> Result<()> {
    let force_sync = false;
    sync_database_with_options(db, force_sync).await
}
pub async fn sync_database_with_options(db: &Database, force_sync: bool) -> Result<()> {
if force_sync {
tide_warn!("Database FORCE sync mode is ENABLED - using SeaORM apply mode!");
} else {
tide_warn!("Database sync mode is ENABLED - DO NOT use in production!");
}
let conn = db.__internal_connection();
let backend = conn.get_database_backend();
let entity_count = SyncRegistry::entity_count();
let legacy_count = SyncRegistry::legacy_count();
let total_count = entity_count + legacy_count;
if total_count == 0 {
tide_info!("No models registered for sync");
return Ok(());
}
tide_info!("Syncing {} model(s) using SeaORM SchemaBuilder...", total_count);
tide_debug!(" - {} entity-based models", entity_count);
tide_debug!(" - {} legacy schema models", legacy_count);
if entity_count > 0 {
let schema_builder = SyncRegistry::build_schema_builder(backend);
if force_sync {
tide_debug!(" Using SeaORM SchemaBuilder.apply() - fresh schema creation");
schema_builder.apply(conn).await
.map_err(|e| Error::query(format!("Schema apply failed: {}", e)))?;
} else {
tide_debug!(" Using SeaORM SchemaBuilder.sync() - incremental sync");
schema_builder.sync(conn).await
.map_err(|e| Error::query(format!("Schema sync failed: {}", e)))?;
}
}
if legacy_count > 0 {
tide_debug!(" Processing {} legacy schema(s)...", legacy_count);
sync_legacy_schemas(db, force_sync).await?;
}
tide_info!("Database sync completed using SeaORM");
Ok(())
}
async fn sync_legacy_schemas(db: &Database, force_sync: bool) -> Result<()> {
let models = SyncRegistry::get_all();
let conn = db.__internal_connection();
let backend = conn.get_database_backend();
for model in models {
let table_exists = check_table_exists(conn, &model.schema_name, &model.table_name, backend).await?;
if force_sync && table_exists {
let drop_sql = match backend {
DbBackend::Postgres => format!("DROP TABLE IF EXISTS \"{}\".\"{}\" CASCADE", model.schema_name, model.table_name),
DbBackend::MySql => format!("DROP TABLE IF EXISTS `{}`", model.table_name),
DbBackend::Sqlite => format!("DROP TABLE IF EXISTS \"{}\"", model.table_name),
_ => format!("DROP TABLE IF EXISTS \"{}\"", model.table_name),
};
let drop_stmt = Statement::from_string(backend, drop_sql);
conn.execute_raw(drop_stmt)
.await
.map_err(|e| Error::query(e.to_string()))?;
tide_warn!("Dropped legacy table: {}", model.table_name);
}
if !table_exists || force_sync {
create_table_from_legacy_schema(conn, &model, backend).await?;
tide_info!("Created legacy table: {}", model.table_name);
} else {
tide_debug!("Legacy table exists: {}", model.table_name);
}
}
Ok(())
}
/// Reports whether `table` exists, using a backend-specific catalog query.
///
/// * Postgres looks in `information_schema.tables` for `schema` + `table`.
/// * MySQL looks in `information_schema.tables` of the current database; `schema` is unused.
/// * SQLite looks in `sqlite_master`; `schema` is unused.
/// * Any other backend logs a warning and falls back to the SQLite query.
///
/// Single quotes in `schema` and `table` are doubled before being embedded in the
/// single-quoted SQL literals, so such a name still yields well-formed SQL. This is
/// standard SQL literal escaping, not a general defence against hostile input.
///
/// # Errors
///
/// Returns `Error::query` when the query itself fails. A row whose first column cannot be
/// decoded is treated as "table does not exist" rather than as an error.
async fn check_table_exists(
    conn: &sea_orm::DatabaseConnection,
    schema: &str,
    table: &str,
    backend: DbBackend,
) -> Result<bool> {
    // Escape for use inside '...' literals: ' becomes ''.
    let schema = schema.replace('\'', "''");
    let table = table.replace('\'', "''");

    // Shared by the SQLite arm and the unknown-backend fallback.
    let sqlite_probe = || {
        format!(
            "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = 'table' AND name = '{}'",
            table
        )
    };

    let sql = match backend {
        DbBackend::Postgres => format!(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = '{}' AND table_name = '{}')",
            schema, table
        ),
        DbBackend::MySql => format!(
            "SELECT COUNT(*) > 0 FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
            table
        ),
        DbBackend::Sqlite => sqlite_probe(),
        other => {
            tide_warn!("Unknown backend {:?} in check_table_exists, falling back to SQLite SQL", other);
            sqlite_probe()
        }
    };

    let stmt = Statement::from_string(backend, sql);
    let result = conn
        .query_one_raw(stmt)
        .await
        .map_err(|e| Error::query(e.to_string()))?;

    match result {
        Some(row) => {
            let exists: bool = match backend {
                // Postgres `EXISTS` is read as a boolean.
                DbBackend::Postgres => row.try_get_by_index(0).unwrap_or(false),
                // The other queries are read as an integer: non-zero means present.
                _ => {
                    let val: i32 = row.try_get_by_index(0).unwrap_or(0);
                    val > 0
                }
            };
            Ok(exists)
        }
        None => Ok(false),
    }
}
/// Builds and executes `CREATE TABLE IF NOT EXISTS` for a directly registered schema.
///
/// On Postgres the table is created as `"schema_name"."table_name"`, so it lands in the
/// same schema that the existence check and the forced DROP use. Other backends, and an
/// empty `schema_name`, use the bare table name.
///
/// # Errors
///
/// Returns `Error::query` when executing the statement fails.
async fn create_table_from_legacy_schema(
    conn: &sea_orm::DatabaseConnection,
    model: &ModelSchema,
    backend: DbBackend,
) -> Result<()> {
    let mut table = Table::create();

    // Only Postgres gets a schema qualifier: `schema_name` defaults to "public", which is
    // a Postgres notion and is not used by the MySQL/SQLite existence checks or DROPs.
    let qualify_with_schema =
        matches!(backend, DbBackend::Postgres) && !model.schema_name.is_empty();
    if qualify_with_schema {
        table.table((
            Alias::new(&model.schema_name),
            Alias::new(&model.table_name),
        ));
    } else {
        table.table(Alias::new(&model.table_name));
    }

    for col in &model.columns {
        let mut column = SeaColumnDef::new(Alias::new(&col.name));
        apply_column_type(&mut column, &col.col_type, col.auto_increment, backend);
        if col.primary_key {
            column.primary_key();
        }
        if col.auto_increment {
            column.auto_increment();
        }
        // NOT NULL is emitted only for columns that are neither primary key nor
        // auto-increment — presumably because those already imply it; confirm.
        if !col.nullable && !col.primary_key && !col.auto_increment {
            column.not_null();
        }
        if let Some(ref default) = col.default {
            let default_owned = default.clone();
            column.default(Expr::cust(default_owned));
        }
        table.col(&mut column);
    }
    table.if_not_exists();

    // Unknown backends are rendered with the Postgres builder.
    let sql = match backend {
        DbBackend::Postgres => table.to_string(PostgresQueryBuilder),
        DbBackend::MySql => table.to_string(MysqlQueryBuilder),
        DbBackend::Sqlite => table.to_string(SqliteQueryBuilder),
        _ => table.to_string(PostgresQueryBuilder),
    };
    let create_stmt = Statement::from_string(backend, sql);
    conn.execute_raw(create_stmt)
        .await
        .map_err(|e| Error::query(e.to_string()))?;
    Ok(())
}
/// Sets the SQL type of `column` from a Rust type name.
///
/// `rust_type` is stripped of whitespace and of one outer `Option<...>` wrapper before
/// matching. Exact names are matched first, then the date/time names by substring, then
/// the array names. Anything unrecognised is logged and stored as TEXT.
///
/// `_auto_increment` and `_backend` are accepted but currently unused.
fn apply_column_type(
    column: &mut sea_orm::sea_query::ColumnDef,
    rust_type: &str,
    _auto_increment: bool,
    _backend: DbBackend,
) {
    let normalized = normalize_rust_type(rust_type);
    let inner_type = normalized
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix('>'))
        .unwrap_or(&normalized);
    match inner_type {
        "i8" | "u8" | "i16" | "u16" => { column.small_integer(); }
        "i32" => { column.integer(); }
        "u32" | "i64" => { column.big_integer(); }
        "u64" | "i128" | "u128" => { column.decimal(); }
        "isize" | "usize" => { column.big_integer(); }
        "f32" => { column.float(); }
        "f64" => { column.double(); }
        "bool" => { column.boolean(); }
        "String" | "&str" => { column.text(); }
        "Uuid" => { column.uuid(); }
        "Json" | "JsonValue" | "serde_json::Value" | "Value" | "Jsonb" => { column.json_binary(); }
        "Vec<u8>" | "Bytes" => { column.binary(); }
        "Decimal" | "BigDecimal" => { column.decimal(); }
        // "NaiveDateTime" contains "DateTime", so it has to be tested first; otherwise
        // this arm can never match and naive timestamps get a time zone.
        t if t.contains("NaiveDateTime") => { column.timestamp(); }
        t if t.contains("DateTime") => { column.timestamp_with_time_zone(); }
        t if t.contains("NaiveDate") => { column.date(); }
        t if t.contains("NaiveTime") => { column.time(); }
        "Vec<i32>" | "IntArray" => { column.array(SeaColumnType::Integer); }
        "Vec<i64>" | "BigIntArray" => { column.array(SeaColumnType::BigInteger); }
        "Vec<String>" | "TextArray" => { column.array(SeaColumnType::Text); }
        "Vec<bool>" | "BoolArray" => { column.array(SeaColumnType::Boolean); }
        "Vec<f64>" | "FloatArray" => { column.array(SeaColumnType::Double); }
        unknown_type => {
            tide_warn!(
                "Unknown Rust type '{}' mapped to TEXT column. Consider adding explicit type mapping.",
                unknown_type
            );
            column.text();
        }
    };
}
/// Returns `rust_type` with every whitespace character removed,
/// e.g. `"Option< Vec<u8> >"` becomes `"Option<Vec<u8>>"`.
pub fn normalize_rust_type(rust_type: &str) -> String {
    rust_type.split_whitespace().collect()
}