use parking_lot::RwLock;
use std::sync::OnceLock;
use crate::database::Database;
use crate::error::{Error, Result};
use crate::{tide_debug, tide_info, tide_warn};
use sea_orm::{
ConnectionTrait, DbBackend, EntityTrait, Statement,
schema::{Schema, SchemaBuilder},
sea_query::{
Alias, ColumnDef as SeaColumnDef, ColumnType as SeaColumnType, Expr, Index,
MysqlQueryBuilder, PostgresQueryBuilder, SqliteQueryBuilder, Table,
},
};
/// A deferred registration closure that adds a single entity to a SeaORM
/// `SchemaBuilder` when the schema is eventually built.
pub type EntityRegistrationFn = Box<dyn Fn(SchemaBuilder) -> SchemaBuilder + Send + Sync>;

// Process-wide registries, lazily initialised on first use. Entity-based
// models and hand-built `ModelSchema`s are tracked separately because they
// are synced through different code paths (SeaORM builder vs. raw SQL).
static ENTITY_REGISTRY: OnceLock<RwLock<Vec<EntityRegistrationFn>>> = OnceLock::new();
static MODEL_SCHEMAS: OnceLock<RwLock<Vec<ModelSchema>>> = OnceLock::new();

/// Returns the global entity registration list, creating it on first access.
fn get_entity_registry() -> &'static RwLock<Vec<EntityRegistrationFn>> {
    ENTITY_REGISTRY.get_or_init(|| RwLock::new(Vec::new()))
}

/// Returns the global `ModelSchema` list, creating it on first access.
fn get_model_schemas() -> &'static RwLock<Vec<ModelSchema>> {
    MODEL_SCHEMAS.get_or_init(|| RwLock::new(Vec::new()))
}
/// Implemented by models that describe their table layout as a
/// [`ModelSchema`] rather than as a SeaORM entity.
pub trait SyncModel {
    /// Builds the schema description for this model's table.
    fn sync_schema() -> ModelSchema;
    /// Registers this model's schema with the global [`SyncRegistry`]
    /// (deduplicated there by table name).
    fn register_for_sync() {
        SyncRegistry::register_schema(Self::sync_schema());
    }
}
/// Bulk registration hook: implemented for tuples of [`SyncModel`] types via
/// `impl_register_models_tuples!`, so a whole model list can be registered
/// with one call.
pub trait RegisterModels {
    fn register_all();
}

// The unit type registers nothing — usable as a "no models" placeholder.
impl RegisterModels for () {
    fn register_all() {}
}
// Recursively implements `RegisterModels` for tuples: the `(A, B, ...)` arm
// first re-invokes the macro on the tail `(B, ...)`, so a single invocation
// with N type parameters produces impls for every tuple arity from 1 to N.
macro_rules! impl_register_models_tuples {
    // Base case: the 1-tuple.
    ($first:ident) => {
        impl<$first: SyncModel> RegisterModels for ($first,) {
            fn register_all() {
                $first::register_for_sync();
            }
        }
    };
    // Recursive case: peel off the head, recurse on the rest, then emit the
    // impl for the full tuple which registers every element in order.
    ($first:ident, $($rest:ident),+) => {
        impl_register_models_tuples!($($rest),+);
        impl<$first: SyncModel, $($rest: SyncModel),+> RegisterModels for ($first, $($rest),+) {
            fn register_all() {
                $first::register_for_sync();
                $($rest::register_for_sync();)+
            }
        }
    };
}
// Generate `RegisterModels` impls for tuples of 1 up to 200 model types.
impl_register_models_tuples!(
    T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16, T17, T18, T19, T20, T21,
    T22, T23, T24, T25, T26, T27, T28, T29, T30, T31, T32, T33, T34, T35, T36, T37, T38, T39, T40,
    T41, T42, T43, T44, T45, T46, T47, T48, T49, T50, T51, T52, T53, T54, T55, T56, T57, T58, T59,
    T60, T61, T62, T63, T64, T65, T66, T67, T68, T69, T70, T71, T72, T73, T74, T75, T76, T77, T78,
    T79, T80, T81, T82, T83, T84, T85, T86, T87, T88, T89, T90, T91, T92, T93, T94, T95, T96, T97,
    T98, T99, T100, T101, T102, T103, T104, T105, T106, T107, T108, T109, T110, T111, T112, T113,
    T114, T115, T116, T117, T118, T119, T120, T121, T122, T123, T124, T125, T126, T127, T128, T129,
    T130, T131, T132, T133, T134, T135, T136, T137, T138, T139, T140, T141, T142, T143, T144, T145,
    T146, T147, T148, T149, T150, T151, T152, T153, T154, T155, T156, T157, T158, T159, T160, T161,
    T162, T163, T164, T165, T166, T167, T168, T169, T170, T171, T172, T173, T174, T175, T176, T177,
    T178, T179, T180, T181, T182, T183, T184, T185, T186, T187, T188, T189, T190, T191, T192, T193,
    T194, T195, T196, T197, T198, T199, T200
);
/// Facade over the process-wide sync registries: queues SeaORM entities and
/// TideORM [`ModelSchema`]s for later schema synchronisation.
pub struct SyncRegistry;

impl SyncRegistry {
    /// Queues a SeaORM entity; the actual registration runs lazily when the
    /// schema builder is assembled.
    pub fn register_entity<E: EntityTrait + Default + 'static>() {
        let register: EntityRegistrationFn =
            Box::new(|builder: SchemaBuilder| builder.register(E::default()));
        get_entity_registry().write().push(register);
    }

    /// Assembles a `SchemaBuilder` for `backend` by replaying every queued
    /// entity registration in insertion order.
    pub fn build_schema_builder(backend: DbBackend) -> SchemaBuilder {
        let guard = get_entity_registry().read();
        guard
            .iter()
            .fold(Schema::new(backend).builder(), |builder, register| {
                register(builder)
            })
    }

    /// Number of queued entity registrations.
    pub fn entity_count() -> usize {
        get_entity_registry().read().len()
    }

    /// Number of registered `ModelSchema`s.
    pub fn schema_count() -> usize {
        get_model_schemas().read().len()
    }

    /// Empties both registries (entities and schemas).
    pub fn clear() {
        get_entity_registry().write().clear();
        get_model_schemas().write().clear();
    }

    /// Adds `schema` unless a schema with the same table name is already
    /// registered (first registration wins).
    pub fn register_schema(schema: ModelSchema) {
        let mut guard = get_model_schemas().write();
        let duplicate = guard.iter().any(|s| s.table_name == schema.table_name);
        if !duplicate {
            guard.push(schema);
        }
    }

    /// Snapshot of all registered `ModelSchema`s.
    pub fn get_all_schemas() -> Vec<ModelSchema> {
        get_model_schemas().read().to_vec()
    }
}
/// Database-agnostic description of one table column, consumed when a
/// TideORM table is created from a [`ModelSchema`].
#[derive(Debug, Clone)]
pub struct ColumnDef {
    pub name: String,
    // The Rust type name this column stores (mapped to SQL elsewhere).
    pub col_type: String,
    pub nullable: bool,
    pub primary_key: bool,
    pub auto_increment: bool,
    // Raw SQL default expression, if any.
    pub default: Option<String>,
}

impl ColumnDef {
    /// Creates a nullable, non-key column with no default.
    pub fn new(name: impl Into<String>, col_type: impl Into<String>) -> Self {
        ColumnDef {
            name: name.into(),
            col_type: col_type.into(),
            nullable: true,
            primary_key: false,
            auto_increment: false,
            default: None,
        }
    }

    /// Marks the column as primary key; a key column is never nullable.
    pub fn primary_key(self) -> Self {
        Self {
            primary_key: true,
            nullable: false,
            ..self
        }
    }

    /// Marks the column as auto-incrementing.
    pub fn auto_increment(self) -> Self {
        Self {
            auto_increment: true,
            ..self
        }
    }

    /// Marks the column NOT NULL.
    pub fn not_null(self) -> Self {
        Self {
            nullable: false,
            ..self
        }
    }

    /// Sets a raw SQL default expression for the column.
    pub fn default(self, expr: impl Into<String>) -> Self {
        Self {
            default: Some(expr.into()),
            ..self
        }
    }
}
/// Describes one table (name, schema, columns, primary keys) for the
/// TideORM sync path.
#[derive(Debug, Clone)]
pub struct ModelSchema {
    pub table_name: String,
    // Target database schema; defaults to "public".
    pub schema_name: String,
    pub columns: Vec<ColumnDef>,
    // Column names forming the primary key (composite when more than one).
    pub primary_keys: Vec<String>,
}

impl ModelSchema {
    /// Creates an empty schema for `table_name` in the "public" schema.
    pub fn new(table_name: impl Into<String>) -> Self {
        ModelSchema {
            table_name: table_name.into(),
            schema_name: String::from("public"),
            columns: Vec::new(),
            primary_keys: Vec::new(),
        }
    }

    /// Overrides the target database schema.
    pub fn schema(self, schema: impl Into<String>) -> Self {
        Self {
            schema_name: schema.into(),
            ..self
        }
    }

    /// Appends a single column definition.
    pub fn column(mut self, col: ColumnDef) -> Self {
        self.columns.push(col);
        self
    }

    /// Appends several column definitions at once.
    pub fn columns(mut self, cols: Vec<ColumnDef>) -> Self {
        self.columns.extend(cols);
        self
    }

    /// Replaces the primary-key column list.
    pub fn primary_keys(self, columns: Vec<String>) -> Self {
        Self {
            primary_keys: columns,
            ..self
        }
    }
}
/// Convenience wrapper around [`sync_database_with_options`] that runs in
/// incremental (non-force) mode.
pub async fn sync_database(db: &Database) -> Result<()> {
    sync_database_with_options(db, false).await
}
/// Synchronises the database schema for every registered model.
///
/// * `force_sync == false` — incremental: SeaORM `SchemaBuilder::sync` for
///   entity models plus create-if-missing for TideORM schemas.
/// * `force_sync == true`  — destructive: `SchemaBuilder::apply` plus
///   drop-and-recreate for TideORM tables. Never use in production.
pub async fn sync_database_with_options(db: &Database, force_sync: bool) -> Result<()> {
    if force_sync {
        tide_warn!("Database FORCE sync mode is ENABLED - using SeaORM apply mode!");
    } else {
        tide_warn!("Database sync mode is ENABLED - DO NOT use in production!");
    }
    let conn = db.__internal_connection()?;
    let backend = conn.get_database_backend();
    let entity_count = SyncRegistry::entity_count();
    let schema_count = SyncRegistry::schema_count();
    let total_count = entity_count + schema_count;
    // Nothing registered: succeed without touching the database.
    if total_count == 0 {
        tide_info!("No models registered for sync");
        return Ok(());
    }
    tide_info!(
        "Syncing {} model(s) using SeaORM SchemaBuilder...",
        total_count
    );
    tide_debug!(" - {} entity-based models", entity_count);
    tide_debug!(" - {} TideORM schema models", schema_count);
    if entity_count > 0 {
        let schema_builder = SyncRegistry::build_schema_builder(backend);
        // The cfg gates the entire if/else statement: it only exists when at
        // least one SQL backend feature is compiled in.
        #[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
        if force_sync {
            tide_debug!(" Using SeaORM SchemaBuilder.apply() - fresh schema creation");
            schema_builder
                .apply(&conn)
                .await
                .map_err(|e| Error::query(format!("Schema apply failed: {}", e)))?;
        } else {
            tide_debug!(" Using SeaORM SchemaBuilder.sync() - incremental sync");
            schema_builder
                .sync(&conn)
                .await
                .map_err(|e| Error::query(format!("Schema sync failed: {}", e)))?;
        }
        // With no backend feature we cannot emit SQL at all, so fail loudly
        // instead of silently skipping the registered entities.
        #[cfg(not(any(feature = "postgres", feature = "mysql", feature = "sqlite")))]
        {
            let _ = schema_builder;
            return Err(Error::configuration(
                "database sync requires at least one backend feature: postgres, mysql, or sqlite",
            ));
        }
    }
    if schema_count > 0 {
        tide_debug!(" Processing {} TideORM schema(s)...", schema_count);
        sync_model_schemas(db, force_sync).await?;
    }
    tide_info!("Database sync completed using SeaORM");
    Ok(())
}
/// Creates (and, in force mode, first drops) the table for every registered
/// [`ModelSchema`].
///
/// NOTE(review): only the Postgres DROP statement qualifies the table with
/// `schema_name`; the MySQL/SQLite drops — and table creation on every
/// backend — use the bare table name. Confirm this is intended when a
/// non-default schema is configured.
async fn sync_model_schemas(db: &Database, force_sync: bool) -> Result<()> {
    let models = SyncRegistry::get_all_schemas();
    let conn = db.__internal_connection()?;
    let backend = conn.get_database_backend();
    for model in models {
        let table_exists =
            check_table_exists(&conn, &model.schema_name, &model.table_name, backend).await?;
        // Force mode drops any existing table so it can be rebuilt below.
        if force_sync && table_exists {
            let drop_sql = match backend {
                DbBackend::Postgres => format!(
                    "DROP TABLE IF EXISTS \"{}\".\"{}\" CASCADE",
                    model.schema_name, model.table_name
                ),
                DbBackend::MySql => format!("DROP TABLE IF EXISTS `{}`", model.table_name),
                DbBackend::Sqlite => format!("DROP TABLE IF EXISTS \"{}\"", model.table_name),
                // Unknown backends fall back to double-quoted identifiers.
                _ => format!("DROP TABLE IF EXISTS \"{}\"", model.table_name),
            };
            let drop_stmt = Statement::from_string(backend, drop_sql);
            conn.execute_raw(drop_stmt)
                .await
                .map_err(|e| Error::query(e.to_string()))?;
            tide_warn!("Dropped TideORM table: {}", model.table_name);
        }
        if !table_exists || force_sync {
            create_table_from_model_schema(&conn, &model, backend).await?;
            tide_info!("Created TideORM table: {}", model.table_name);
        } else {
            // Incremental mode never alters an existing table.
            tide_debug!("TideORM table exists: {}", model.table_name);
        }
    }
    Ok(())
}
/// Checks whether `table` already exists in the target database.
///
/// Identifier values are interpolated directly into the query text, so they
/// must come from compile-time model definitions — never from user input.
///
/// # Errors
/// Returns [`Error::query`] (wrapped driver error) if the probe query fails.
async fn check_table_exists(
    conn: &sea_orm::DatabaseConnection,
    schema: &str,
    table: &str,
    backend: DbBackend,
) -> Result<bool> {
    let sql = match backend {
        DbBackend::Postgres => format!(
            "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = '{}' AND table_name = '{}')",
            schema, table
        ),
        // MySQL is scoped to the current database; `schema` is not used.
        DbBackend::MySql => format!(
            "SELECT COUNT(*) > 0 FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = '{}'",
            table
        ),
        DbBackend::Sqlite => format!(
            "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = 'table' AND name = '{}'",
            table
        ),
        other => {
            tide_warn!(
                "Unknown backend {:?} in check_table_exists, falling back to SQLite SQL",
                other
            );
            format!(
                "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = 'table' AND name = '{}'",
                table
            )
        }
    };
    let stmt = Statement::from_string(backend, sql);
    let result = conn
        .query_one_raw(stmt)
        .await
        .map_err(|e| Error::query(e.to_string()))?;
    let Some(row) = result else {
        // No row at all — treat as "table does not exist".
        return Ok(false);
    };
    let exists = match backend {
        // Postgres `EXISTS (...)` yields a real boolean column.
        DbBackend::Postgres => row.try_get_by_index::<bool>(0).unwrap_or(false),
        // MySQL/SQLite return the comparison as an integer whose width is
        // driver-dependent (typically BIGINT): try i64 first, then i32, so a
        // decode-type mismatch cannot silently read as "missing".
        _ => row
            .try_get_by_index::<i64>(0)
            .or_else(|_| row.try_get_by_index::<i32>(0).map(i64::from))
            .map(|v| v > 0)
            .unwrap_or(false),
    };
    Ok(exists)
}
/// Builds and executes a `CREATE TABLE IF NOT EXISTS` statement for `model`
/// using sea_query's portable table builder.
async fn create_table_from_model_schema(
    conn: &sea_orm::DatabaseConnection,
    model: &ModelSchema,
    backend: DbBackend,
) -> Result<()> {
    let mut table = Table::create();
    table.table(Alias::new(&model.table_name));
    // With more than one PK column the key must be emitted as a table-level
    // constraint instead of per-column `PRIMARY KEY` flags.
    let composite_primary_key = model.primary_keys.len() > 1;
    for col in &model.columns {
        let mut column = SeaColumnDef::new(Alias::new(&col.name));
        apply_column_type(&mut column, &col.col_type, col.auto_increment, backend);
        if col.primary_key && !composite_primary_key {
            column.primary_key();
        }
        if col.auto_increment {
            column.auto_increment();
        }
        // NOT NULL is implied by a single-column PK and by auto-increment,
        // so it is only added explicitly in the remaining cases.
        if (composite_primary_key || !col.primary_key) && !col.auto_increment && !col.nullable {
            column.not_null();
        }
        if let Some(ref default) = col.default {
            let default_owned = default.clone();
            // `Expr::cust` passes the default expression through as raw SQL.
            column.default(Expr::cust(default_owned));
        }
        table.col(&mut column);
    }
    if composite_primary_key {
        let mut primary_key = Index::create();
        for column in &model.primary_keys {
            primary_key.col(Alias::new(column));
        }
        table.primary_key(&mut primary_key);
    }
    table.if_not_exists();
    // Render backend-specific SQL; unknown backends fall back to Postgres
    // syntax as a best effort.
    let sql = match backend {
        DbBackend::Postgres => table.to_string(PostgresQueryBuilder),
        DbBackend::MySql => table.to_string(MysqlQueryBuilder),
        DbBackend::Sqlite => table.to_string(SqliteQueryBuilder),
        _ => table.to_string(PostgresQueryBuilder),
    };
    let create_stmt = Statement::from_string(backend, sql);
    conn.execute_raw(create_stmt)
        .await
        .map_err(|e| Error::query(e.to_string()))?;
    Ok(())
}
/// Maps a Rust type name (as written in a [`ColumnDef`]) onto the matching
/// sea_query column type, mutating `column` in place.
///
/// A single `Option<...>` wrapper is stripped first — nullability is tracked
/// separately on the `ColumnDef` — and path-qualified names are collapsed via
/// [`canonical_schema_type`]. Unknown types fall back to TEXT with a warning.
fn apply_column_type(
    column: &mut sea_orm::sea_query::ColumnDef,
    rust_type: &str,
    _auto_increment: bool,
    _backend: DbBackend,
) {
    let normalized = normalize_rust_type(rust_type);
    let inner_type = normalized
        .strip_prefix("Option<")
        .and_then(|s| s.strip_suffix(">"))
        .unwrap_or(&normalized);
    let inner_type = canonical_schema_type(inner_type);
    match inner_type.as_str() {
        "i8" | "u8" | "i16" | "u16" => {
            column.small_integer();
        }
        "i32" => {
            column.integer();
        }
        // u32 needs the larger signed type to hold its full range.
        "u32" | "i64" => {
            column.big_integer();
        }
        // Values wider than 64 signed bits go to DECIMAL.
        "u64" | "i128" | "u128" => {
            column.decimal();
        }
        "isize" | "usize" => {
            column.big_integer();
        }
        "f32" => {
            column.float();
        }
        "f64" => {
            column.double();
        }
        "bool" => {
            column.boolean();
        }
        // "Text" is a canonical alias emitted by `canonical_schema_type`.
        "String" | "&str" | "Text" => {
            column.text();
        }
        "Uuid" => {
            column.uuid();
        }
        "Json" | "JsonValue" | "serde_json::Value" | "Value" | "Jsonb" => {
            column.json_binary();
        }
        "Vec<u8>" | "Bytes" => {
            column.binary();
        }
        "Decimal" | "BigDecimal" => {
            column.decimal();
        }
        // Guard order matters: "NaiveDateTime" contains both "NaiveDate" and
        // "DateTime" as substrings, so the most specific checks come first.
        // (Previously the bare "DateTime" guard ran first, making the
        // NaiveDateTime arm unreachable and mapping naive datetimes to a
        // timezone-aware column.)
        t if t.contains("NaiveDateTime") => {
            column.timestamp();
        }
        t if t.contains("NaiveDate") => {
            column.date();
        }
        t if t.contains("NaiveTime") => {
            column.time();
        }
        t if t.contains("DateTime") => {
            column.timestamp_with_time_zone();
        }
        "Vec<i32>" | "IntArray" => {
            column.array(SeaColumnType::Integer);
        }
        "Vec<i64>" | "BigIntArray" => {
            column.array(SeaColumnType::BigInteger);
        }
        "Vec<String>" | "TextArray" => {
            column.array(SeaColumnType::Text);
        }
        "Vec<bool>" | "BoolArray" => {
            column.array(SeaColumnType::Boolean);
        }
        "Vec<f64>" | "FloatArray" => {
            column.array(SeaColumnType::Double);
        }
        unknown_type => {
            tide_warn!(
                "Unknown Rust type '{}' mapped to TEXT column. Consider adding explicit type mapping.",
                unknown_type
            );
            column.text();
        }
    };
}
/// Collapses a possibly path-qualified Rust type name to its canonical
/// schema alias (e.g. `uuid::Uuid` -> `Uuid`, `chrono::NaiveDate` ->
/// `NaiveDate`). Names matching no known alias are returned trimmed but
/// otherwise unchanged.
fn canonical_schema_type(rust_type: &str) -> String {
    // Aliases recognised either as an exact name or as a `::Alias` suffix.
    const ALIASES: [&str; 15] = [
        "Json",
        "JsonValue",
        "JsonArray",
        "Jsonb",
        "IntArray",
        "BigIntArray",
        "TextArray",
        "BoolArray",
        "FloatArray",
        "Decimal",
        "Uuid",
        "NaiveDate",
        "NaiveTime",
        "NaiveDateTime",
        "Text",
    ];
    let trimmed = rust_type.trim();
    ALIASES
        .iter()
        .find(|alias| trimmed == **alias || trimmed.ends_with(&format!("::{}", alias)))
        .map(|alias| (*alias).to_string())
        .unwrap_or_else(|| trimmed.to_string())
}
/// Removes all whitespace from a Rust type string so that spellings like
/// `Option< i32 >` and `Option<i32>` compare equal.
pub fn normalize_rust_type(rust_type: &str) -> String {
    rust_type.split_whitespace().collect()
}