use std::sync::OnceLock;
use parking_lot::RwLock;
use crate::database::Database;
use crate::error::{Error, Result};
use sea_orm::{
ConnectionTrait, DbBackend, EntityTrait, Statement,
schema::{Schema, SchemaBuilder},
sea_query::{
Table, ColumnDef as SeaColumnDef, Alias, Expr,
ColumnType as SeaColumnType, PostgresQueryBuilder,
MysqlQueryBuilder, SqliteQueryBuilder,
},
};
pub type EntityRegistrationFn = Box<dyn Fn(SchemaBuilder) -> SchemaBuilder + Send + Sync>;
static ENTITY_REGISTRY: OnceLock<RwLock<Vec<EntityRegistrationFn>>> = OnceLock::new();
static DIRECT_SCHEMAS: OnceLock<RwLock<Vec<ModelSchema>>> = OnceLock::new();
fn get_entity_registry() -> &'static RwLock<Vec<EntityRegistrationFn>> {
ENTITY_REGISTRY.get_or_init(|| RwLock::new(Vec::new()))
}
fn get_direct_schemas() -> &'static RwLock<Vec<ModelSchema>> {
DIRECT_SCHEMAS.get_or_init(|| RwLock::new(Vec::new()))
}
/// A model that can describe its own table layout for database sync.
pub trait SyncModel {
    /// Returns the table description used when this model is synced.
    fn sync_schema() -> ModelSchema;
    /// Adds this model's [`ModelSchema`] to the global [`SyncRegistry`].
    ///
    /// The registry is keyed by table name: if a schema with the same table name is
    /// already registered, this call leaves the registry unchanged.
    fn register_for_sync() {
        let schema = Self::sync_schema();
        SyncRegistry::register(schema);
    }
}
pub trait RegisterModels {
fn register_all();
}
impl RegisterModels for () {
fn register_all() {}
}
impl<A: SyncModel> RegisterModels for (A,) {
fn register_all() {
A::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel> RegisterModels for (A, B) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel> RegisterModels for (A, B, C) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel> RegisterModels for (A, B, C, D) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel> RegisterModels for (A, B, C, D, E) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel> RegisterModels for (A, B, C, D, E, F) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel, G: SyncModel> RegisterModels for (A, B, C, D, E, F, G) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
G::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel, G: SyncModel, H: SyncModel> RegisterModels for (A, B, C, D, E, F, G, H) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
G::register_for_sync();
H::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel, G: SyncModel, H: SyncModel, I: SyncModel> RegisterModels for (A, B, C, D, E, F, G, H, I) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
G::register_for_sync();
H::register_for_sync();
I::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel, G: SyncModel, H: SyncModel, I: SyncModel, J: SyncModel> RegisterModels for (A, B, C, D, E, F, G, H, I, J) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
G::register_for_sync();
H::register_for_sync();
I::register_for_sync();
J::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel, G: SyncModel, H: SyncModel, I: SyncModel, J: SyncModel, K: SyncModel> RegisterModels for (A, B, C, D, E, F, G, H, I, J, K) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
G::register_for_sync();
H::register_for_sync();
I::register_for_sync();
J::register_for_sync();
K::register_for_sync();
}
}
impl<A: SyncModel, B: SyncModel, C: SyncModel, D: SyncModel, E: SyncModel, F: SyncModel, G: SyncModel, H: SyncModel, I: SyncModel, J: SyncModel, K: SyncModel, L: SyncModel> RegisterModels for (A, B, C, D, E, F, G, H, I, J, K, L) {
fn register_all() {
A::register_for_sync();
B::register_for_sync();
C::register_for_sync();
D::register_for_sync();
E::register_for_sync();
F::register_for_sync();
G::register_for_sync();
H::register_for_sync();
I::register_for_sync();
J::register_for_sync();
K::register_for_sync();
L::register_for_sync();
}
}
/// Global registry of the models that database sync should create or update.
///
/// Two kinds of models are tracked in separate lists:
/// * entity-based models, stored as closures that add a SeaORM entity to a `SchemaBuilder`;
/// * "legacy" models, stored as hand-written [`ModelSchema`] descriptions.
pub struct SyncRegistry;
impl SyncRegistry {
    /// Queues entity `E` to be added to every schema builder built afterwards.
    ///
    /// No de-duplication happens here: registering the same entity twice stores two steps.
    pub fn register_entity<E: EntityTrait + Default + 'static>() {
        let mut steps = get_entity_registry().write();
        steps.push(Box::new(|builder: SchemaBuilder| builder.register(E::default())));
    }
    /// Builds a SeaORM `SchemaBuilder` for `backend` with every registered entity added,
    /// in registration order.
    pub fn build_schema_builder(backend: DbBackend) -> SchemaBuilder {
        let steps = get_entity_registry().read();
        let schema = Schema::new(backend);
        steps
            .iter()
            .fold(schema.builder(), |builder, step| step(builder))
    }
    /// Number of entity-based models registered so far.
    pub fn entity_count() -> usize {
        get_entity_registry().read().len()
    }
    /// Number of legacy [`ModelSchema`] models registered so far.
    pub fn legacy_count() -> usize {
        get_direct_schemas().read().len()
    }
    /// Removes every registered entity and legacy schema.
    pub fn clear() {
        // Both write guards stay alive until the end of the function, entity list first.
        let mut steps = get_entity_registry().write();
        steps.clear();
        let mut legacy = get_direct_schemas().write();
        legacy.clear();
    }
    /// Registers a legacy schema unless one with the same `table_name` is already present.
    pub fn register(schema: ModelSchema) {
        let mut legacy = get_direct_schemas().write();
        let already_registered = legacy
            .iter()
            .any(|existing| existing.table_name == schema.table_name);
        if already_registered {
            return;
        }
        legacy.push(schema);
    }
    /// Returns a snapshot (clone) of all registered legacy schemas.
    pub fn get_all() -> Vec<ModelSchema> {
        get_direct_schemas().read().clone()
    }
}
/// Hand-written description of one table column, used by legacy [`ModelSchema`]s.
///
/// `col_type` holds a Rust type name (e.g. `"i64"`, `"Option<String>"`) that is mapped
/// to an SQL type when the table is created.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    pub name: String,
    pub col_type: String,
    pub nullable: bool,
    pub primary_key: bool,
    pub auto_increment: bool,
    pub default: Option<String>,
}
impl ColumnDef {
    /// Creates a nullable, non-key, non-auto-increment column with no default.
    pub fn new(name: impl Into<String>, col_type: impl Into<String>) -> Self {
        ColumnDef {
            name: name.into(),
            col_type: col_type.into(),
            nullable: true,
            primary_key: false,
            auto_increment: false,
            default: None,
        }
    }
    /// Marks the column as the primary key; this also makes it non-nullable.
    pub fn primary_key(self) -> Self {
        Self { primary_key: true, nullable: false, ..self }
    }
    /// Marks the column as auto-incrementing.
    pub fn auto_increment(self) -> Self {
        Self { auto_increment: true, ..self }
    }
    /// Marks the column as non-nullable.
    pub fn not_null(self) -> Self {
        Self { nullable: false, ..self }
    }
    /// Sets a raw SQL default expression, replacing any previous one.
    pub fn default(self, expr: impl Into<String>) -> Self {
        Self { default: Some(expr.into()), ..self }
    }
}
/// Hand-written description of a table: its name, schema, and columns.
#[derive(Debug, Clone)]
pub struct ModelSchema {
    pub table_name: String,
    pub schema_name: String,
    pub columns: Vec<ColumnDef>,
}
impl ModelSchema {
    /// Creates an empty table description in the `"public"` schema.
    pub fn new(table_name: impl Into<String>) -> Self {
        ModelSchema {
            table_name: table_name.into(),
            schema_name: String::from("public"),
            columns: Vec::new(),
        }
    }
    /// Overrides the schema name.
    pub fn schema(self, schema: impl Into<String>) -> Self {
        Self { schema_name: schema.into(), ..self }
    }
    /// Appends one column.
    pub fn column(self, col: ColumnDef) -> Self {
        let mut model = self;
        model.columns.push(col);
        model
    }
    /// Appends several columns, preserving their order.
    pub fn columns(self, cols: Vec<ColumnDef>) -> Self {
        let mut model = self;
        model.columns.extend(cols);
        model
    }
}
/// Synchronises all registered models in incremental (non-forced) mode.
///
/// Shorthand for `sync_database_with_options(db, false)`.
///
/// # Errors
/// Propagates any error from [`sync_database_with_options`].
pub async fn sync_database(db: &Database) -> Result<()> {
    const FORCE_SYNC: bool = false;
    sync_database_with_options(db, FORCE_SYNC).await
}
/// Synchronises every registered model with the database behind `db`.
///
/// Entity-based models go through SeaORM's `SchemaBuilder`: `apply()` when
/// `force_sync` is set, `sync()` otherwise. Legacy [`ModelSchema`] models are processed
/// afterwards. Progress is reported on stderr. If nothing is registered, the function
/// returns `Ok(())` without touching the database.
///
/// # Errors
/// Returns [`Error::query`] if the SeaORM apply/sync step fails, and propagates any
/// error from the legacy-schema step.
pub async fn sync_database_with_options(db: &Database, force_sync: bool) -> Result<()> {
    let banner = if force_sync {
        "⚠️ Database FORCE sync mode is ENABLED - using SeaORM apply mode!"
    } else {
        "⚠️ Database sync mode is ENABLED - DO NOT use in production!"
    };
    eprintln!("{}", banner);

    let conn = db.__internal_connection();
    let backend = conn.get_database_backend();
    let (entity_count, legacy_count) = (SyncRegistry::entity_count(), SyncRegistry::legacy_count());
    let total_count = entity_count + legacy_count;
    if total_count == 0 {
        eprintln!("No models registered for sync");
        return Ok(());
    }
    eprintln!("Syncing {} model(s) using SeaORM SchemaBuilder...", total_count);
    eprintln!(" - {} entity-based models", entity_count);
    eprintln!(" - {} legacy schema models", legacy_count);

    if entity_count > 0 {
        let schema_builder = SyncRegistry::build_schema_builder(backend);
        match force_sync {
            true => {
                eprintln!(" Using SeaORM SchemaBuilder.apply() - fresh schema creation");
                schema_builder
                    .apply(conn)
                    .await
                    .map_err(|e| Error::query(format!("Schema apply failed: {}", e)))?;
            }
            false => {
                eprintln!(" Using SeaORM SchemaBuilder.sync() - incremental sync");
                schema_builder
                    .sync(conn)
                    .await
                    .map_err(|e| Error::query(format!("Schema sync failed: {}", e)))?;
            }
        }
    }

    if legacy_count > 0 {
        eprintln!(" Processing {} legacy schema(s)...", legacy_count);
        sync_legacy_schemas(db, force_sync).await?;
    }
    eprintln!(" Database sync completed using SeaORM");
    Ok(())
}
/// Creates every registered legacy [`ModelSchema`] table. In force mode, existing tables
/// are dropped and re-created.
///
/// For each model:
/// * the table exists and `force_sync` is off: the table is left alone;
/// * the table exists and `force_sync` is on: it is dropped, then created from the model;
/// * the table is missing: it is created from the model.
///
/// Progress is reported on stderr.
///
/// # Errors
/// Stops at the first model whose existence check, `DROP`, or `CREATE` fails.
async fn sync_legacy_schemas(db: &Database, force_sync: bool) -> Result<()> {
    let models = SyncRegistry::get_all();
    let conn = db.__internal_connection();
    let backend = conn.get_database_backend();
    for model in &models {
        let exists = check_table_exists(conn, &model.schema_name, &model.table_name, backend).await?;
        if exists && !force_sync {
            eprintln!(" Legacy table exists: {}", model.table_name);
            continue;
        }
        if exists {
            // Force mode: discard the current table so it is rebuilt from the model.
            let target = match backend {
                DbBackend::Postgres => {
                    format!("\"{}\".\"{}\" CASCADE", model.schema_name, model.table_name)
                }
                DbBackend::MySql => format!("`{}`", model.table_name),
                _ => format!("\"{}\"", model.table_name),
            };
            let drop_sql = format!("DROP TABLE IF EXISTS {}", target);
            conn.execute_raw(Statement::from_string(backend, drop_sql))
                .await
                .map_err(|e| Error::query(e.to_string()))?;
            eprintln!(" ⚠️ Dropped legacy table: {}", model.table_name);
        }
        create_table_from_legacy_schema(conn, model, backend).await?;
        eprintln!(" Created legacy table: {}", model.table_name);
    }
    Ok(())
}
async fn check_table_exists(
conn: &sea_orm::DatabaseConnection,
schema: &str,
table: &str,
backend: DbBackend,
) -> Result<bool> {
let sql = match backend {
DbBackend::Postgres => format!(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = '{}' AND table_name = '{}')",
schema, table
),
DbBackend::MySql => format!(
"SELECT COUNT(*) > 0 FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
table
),
DbBackend::Sqlite => format!(
"SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = 'table' AND name = '{}'",
table
),
_ => format!(
"SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = 'table' AND name = '{}'",
table
),
};
let stmt = Statement::from_string(backend, sql);
let result = conn
.query_one_raw(stmt)
.await
.map_err(|e| Error::query(e.to_string()))?;
match result {
Some(row) => {
let exists: bool = match backend {
DbBackend::Postgres => row.try_get_by_index(0).unwrap_or(false),
DbBackend::MySql | DbBackend::Sqlite => {
let val: i32 = row.try_get_by_index(0).unwrap_or(0);
val > 0
}
_ => {
let val: i32 = row.try_get_by_index(0).unwrap_or(0);
val > 0
}
};
Ok(exists)
}
None => Ok(false),
}
}
/// Creates (`IF NOT EXISTS`) the table described by a legacy [`ModelSchema`].
///
/// On Postgres the table name is qualified with `model.schema_name`. That is the same
/// schema the existence check inspects and the force-sync `DROP` targets. Without the
/// qualifier, the table would be created in whatever schema `search_path` selects.
/// Other backends use the bare table name.
///
/// # Errors
/// Returns [`Error::query`] if executing the generated DDL fails.
async fn create_table_from_legacy_schema(
    conn: &sea_orm::DatabaseConnection,
    model: &ModelSchema,
    backend: DbBackend,
) -> Result<()> {
    let mut table = Table::create();
    if matches!(backend, DbBackend::Postgres) {
        table.table((Alias::new(&model.schema_name), Alias::new(&model.table_name)));
    } else {
        table.table(Alias::new(&model.table_name));
    }
    for col in &model.columns {
        let mut column = SeaColumnDef::new(Alias::new(&col.name));
        apply_column_type(&mut column, &col.col_type, col.auto_increment, backend);
        if col.primary_key {
            column.primary_key();
        }
        if col.auto_increment {
            column.auto_increment();
        }
        // Explicit NOT NULL is only emitted for plain columns; primary keys are NOT NULL
        // implicitly in SQL. NOTE(review): non-PK auto-increment columns are also skipped
        // here, so they are not forced NOT NULL. Confirm this is intended.
        if !col.nullable && !col.primary_key && !col.auto_increment {
            column.not_null();
        }
        if let Some(default) = &col.default {
            // The default is a raw SQL expression, emitted as-is.
            column.default(Expr::cust(default.clone()));
        }
        table.col(&mut column);
    }
    table.if_not_exists();
    let sql = match backend {
        DbBackend::Postgres => table.to_string(PostgresQueryBuilder),
        DbBackend::MySql => table.to_string(MysqlQueryBuilder),
        DbBackend::Sqlite => table.to_string(SqliteQueryBuilder),
        _ => table.to_string(PostgresQueryBuilder),
    };
    conn.execute_raw(Statement::from_string(backend, sql))
        .await
        .map_err(|e| Error::query(e.to_string()))?;
    Ok(())
}
/// Sets the SQL type of `column` from a Rust type name such as `"i64"` or
/// `"Option<NaiveDateTime>"`.
///
/// Whitespace is removed first, and a single outer `Option<…>` wrapper is peeled off.
/// Nullability is not decided here. Unknown type names fall back to `TEXT`.
///
/// `_auto_increment` and `_backend` are currently unused; they are kept so the signature
/// stays stable for callers.
fn apply_column_type(
    column: &mut sea_orm::sea_query::ColumnDef,
    rust_type: &str,
    _auto_increment: bool,
    _backend: DbBackend
) {
    let normalized = normalize_rust_type(rust_type);
    let inner_type = normalized
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix('>'))
        .unwrap_or(&normalized);
    match inner_type {
        "i8" | "u8" | "i16" | "u16" => { column.small_integer(); }
        "i32" => { column.integer(); }
        "u32" | "i64" => { column.big_integer(); }
        "u64" | "i128" | "u128" => { column.decimal(); }
        "isize" | "usize" => { column.big_integer(); }
        "f32" => { column.float(); }
        "f64" => { column.double(); }
        "bool" => { column.boolean(); }
        "String" | "&str" => { column.text(); }
        "Uuid" => { column.uuid(); }
        "Json" | "JsonValue" | "serde_json::Value" | "Value" | "Jsonb" => { column.json_binary(); }
        "Vec<u8>" | "Bytes" => { column.binary(); }
        "Decimal" | "BigDecimal" => { column.decimal(); }
        // Order matters for these substring guards. "NaiveDateTime" contains both
        // "DateTime" and "NaiveDate", so it must be tested first; otherwise it would be
        // mapped to a time-zone-aware timestamp.
        t if t.contains("NaiveDateTime") => { column.timestamp(); }
        // NOTE(review): sea-orm's `prelude::DateTime` is an alias for
        // `chrono::NaiveDateTime`, but a bare `DateTime` name is mapped to a tz-aware
        // column here. Confirm this is intended.
        t if t.contains("DateTime") => { column.timestamp_with_time_zone(); }
        t if t.contains("NaiveDate") => { column.date(); }
        t if t.contains("NaiveTime") => { column.time(); }
        "Vec<i32>" | "IntArray" => { column.array(SeaColumnType::Integer); }
        "Vec<i64>" | "BigIntArray" => { column.array(SeaColumnType::BigInteger); }
        "Vec<String>" | "TextArray" => { column.array(SeaColumnType::Text); }
        "Vec<bool>" | "BoolArray" => { column.array(SeaColumnType::Boolean); }
        "Vec<f64>" | "FloatArray" => { column.array(SeaColumnType::Double); }
        _ => { column.text(); }
    };
}
/// Removes every whitespace character from a Rust type name, so that
/// `"Option < Vec<u8> >"` becomes `"Option<Vec<u8>>"`.
///
/// Whitespace means any character for which [`char::is_whitespace`] is true (Unicode-aware).
pub fn normalize_rust_type(rust_type: &str) -> String {
    rust_type.split_whitespace().collect()
}