use std::collections::HashMap;
use reinhardt_db::backends::DatabaseConnection;
use reinhardt_db::migrations::{
ColumnDefinition, Constraint, ForeignKeyAction, Migration, Operation,
executor::DatabaseMigrationExecutor, field_type_string_to_field_type, to_snake_case,
};
use reinhardt_db::orm::Model;
use reinhardt_db::orm::fields::FieldKwarg;
use reinhardt_db::orm::inspection::{FieldInfo, RelationInfo};
use reinhardt_db::orm::relationship::RelationshipType;
/// Errors produced while deriving table schemas from model metadata and applying them.
///
/// Every variant carries a human-readable description that is interpolated into the
/// `Display` output after the variant's fixed prefix.
#[derive(Debug, thiserror::Error)]
pub enum SchemaError {
/// A field's type string could not be converted into a migration field type.
#[error("Field conversion error: {0}")]
FieldConversion(String),
/// Applying a generated migration through the executor failed; holds the executor's error text.
#[error("Migration execution error: {0}")]
MigrationExecution(String),
/// Dependency resolution between models failed.
// NOTE(review): this variant is not constructed anywhere in this file — confirm external users.
#[error("Dependency resolution error: {0}")]
DependencyResolution(String),
/// The model dependency graph contains a cycle, so no creation order exists.
#[error("Circular dependency detected: {0}")]
CircularDependency(String),
}
/// Flattens field keyword arguments into plain strings.
///
/// Scalar values are rendered with their `to_string` form, string-like values are copied
/// as-is, and `Choices` entries are left out of the result entirely.
fn convert_attributes(attributes: &HashMap<String, FieldKwarg>) -> HashMap<String, String> {
    let mut converted = HashMap::new();
    for (key, kwarg) in attributes {
        let rendered = match kwarg {
            FieldKwarg::String(text) | FieldKwarg::Callable(text) => text.clone(),
            FieldKwarg::Int(number) => number.to_string(),
            FieldKwarg::Uint(number) => number.to_string(),
            FieldKwarg::Bool(flag) => flag.to_string(),
            FieldKwarg::Float(value) => value.to_string(),
            // Choice lists have no single-string representation, so they are skipped.
            FieldKwarg::Choices(_) => continue,
        };
        converted.insert(key.clone(), rendered);
    }
    converted
}
/// Builds a migration [`ColumnDefinition`] from a model field's [`FieldInfo`].
///
/// * The column type comes from `field_type_string_to_field_type`, fed with the field's
///   type string and its attributes flattened to strings.
/// * The default value is rendered as a SQL expression: strings become quoted literals
///   (with embedded single quotes doubled), numbers and booleans use their `to_string`
///   form, callables are passed through verbatim, and choice lists render as `NULL`.
/// * `auto_increment` is set when the `auto_increment` or `identity_by_default` attribute
///   is `Bool(true)`, or when the field is the primary key.
///
/// # Errors
///
/// Returns [`SchemaError::FieldConversion`] when the field type string cannot be converted.
pub fn field_info_to_column_definition(
    field_info: &FieldInfo,
) -> Result<ColumnDefinition, SchemaError> {
    let attributes = convert_attributes(&field_info.attributes);
    let field_type = field_type_string_to_field_type(&field_info.field_type, &attributes)
        .map_err(SchemaError::FieldConversion)?;

    let default: Option<String> = field_info.default.as_ref().map(|value| match value {
        // Double embedded single quotes so the value stays a single SQL string literal
        // instead of terminating early (e.g. `O'Brien` -> `'O''Brien'`).
        FieldKwarg::String(text) => format!("'{}'", text.replace('\'', "''")),
        FieldKwarg::Int(number) => number.to_string(),
        FieldKwarg::Uint(number) => number.to_string(),
        FieldKwarg::Bool(flag) => flag.to_string(),
        FieldKwarg::Float(number) => number.to_string(),
        FieldKwarg::Choices(_) => "NULL".to_string(),
        // Callables are emitted verbatim, without quoting.
        FieldKwarg::Callable(expression) => expression.clone(),
    });

    // An attribute counts as enabled only when it is present and exactly `Bool(true)`.
    let flag_enabled =
        |key: &str| matches!(field_info.attributes.get(key), Some(FieldKwarg::Bool(true)));
    // NOTE(review): every primary key is treated as auto-incrementing, including keys that
    // may not be integers — confirm this is intended.
    let auto_increment = flag_enabled("auto_increment")
        || flag_enabled("identity_by_default")
        || field_info.primary_key;

    Ok(ColumnDefinition {
        name: field_info.name.clone(),
        type_definition: field_type,
        not_null: !field_info.nullable,
        unique: field_info.unique,
        primary_key: field_info.primary_key,
        auto_increment,
        default,
    })
}
/// Lists the related models that a model's relations depend on.
///
/// Only `ManyToOne` and `OneToOne` relations contribute an entry; `OneToMany` and
/// `ManyToMany` relations are ignored. Names are returned in relation order and may
/// repeat when several relations target the same model.
pub fn extract_model_dependencies(relationship_metadata: &[RelationInfo]) -> Vec<String> {
    let mut dependencies = Vec::new();
    for relation in relationship_metadata {
        let depends_on_target = matches!(
            relation.relationship_type,
            RelationshipType::ManyToOne | RelationshipType::OneToOne
        );
        if depends_on_target {
            dependencies.push(relation.related_model.clone());
        }
    }
    dependencies
}
/// Resolves the table name to use for `model_name`.
///
/// When `model_infos` is provided and contains an entry whose `name` equals `model_name`,
/// that entry's `table_name` is returned (the first match wins). Otherwise the name is
/// derived by passing `model_name` through `to_snake_case`.
pub fn resolve_table_name_for_model(
    model_name: &str,
    model_infos: Option<&[ModelSchemaInfo]>,
) -> String {
    model_infos
        .and_then(|infos| infos.iter().find(|info| info.name == model_name))
        .map(|info| info.table_name.clone())
        .unwrap_or_else(|| to_snake_case(model_name))
}
/// Parses a textual referential action into a [`ForeignKeyAction`].
///
/// Matching is case-insensitive and ignores leading/trailing whitespace. Two-word actions
/// are accepted with a space, an underscore, or no separator (`"SET NULL"`, `"SET_NULL"`,
/// `"SETNULL"`). Any unrecognized value, including the empty string, falls back to
/// [`ForeignKeyAction::Cascade`].
pub fn parse_fk_action(s: &str) -> ForeignKeyAction {
    // Trim first so padded values such as "SET NULL " are not silently turned into CASCADE.
    let normalized = s.trim().to_uppercase();
    match normalized.as_str() {
        "CASCADE" => ForeignKeyAction::Cascade,
        "SET NULL" | "SETNULL" | "SET_NULL" => ForeignKeyAction::SetNull,
        "SET DEFAULT" | "SETDEFAULT" | "SET_DEFAULT" => ForeignKeyAction::SetDefault,
        "RESTRICT" => ForeignKeyAction::Restrict,
        "NO ACTION" | "NOACTION" | "NO_ACTION" => ForeignKeyAction::NoAction,
        // Unknown input keeps the historical default.
        _ => ForeignKeyAction::Cascade,
    }
}
/// Reads the `on_delete` and `on_update` referential actions from a field's attributes.
///
/// Returns `(on_delete, on_update)`. Each action is parsed with [`parse_fk_action`] when
/// the attribute is present as a string; a missing or non-string attribute yields
/// [`ForeignKeyAction::Cascade`].
pub fn extract_fk_actions(
    field_attrs: &HashMap<String, FieldKwarg>,
) -> (ForeignKeyAction, ForeignKeyAction) {
    let action_for = |key: &str| match field_attrs.get(key) {
        Some(FieldKwarg::String(raw)) => parse_fk_action(raw),
        _ => ForeignKeyAction::Cascade,
    };
    (action_for("on_delete"), action_for("on_update"))
}
/// Derives a table name from a model name by passing it through `to_snake_case`
/// (for example `BlogPost` becomes `blog_post`, as exercised by this module's tests).
pub fn infer_table_name(model_name: &str) -> String {
to_snake_case(model_name)
}
fn find_field_info_for_relation<'a>(
relation_info: &RelationInfo,
fields: &'a [FieldInfo],
) -> Option<&'a FieldInfo> {
let fk_column = relation_info.foreign_key.as_deref().unwrap_or("");
if !fk_column.is_empty()
&& let Some(field) = fields.iter().find(|f| f.name == fk_column)
{
return Some(field);
}
let derived_fk = format!("{}_id", relation_info.name);
fields.iter().find(|f| f.name == derived_fk)
}
/// Converts a relation into the table constraint that enforces it, if any.
///
/// * `ManyToOne` produces a [`Constraint::ForeignKey`] named
///   `fk_<source table>_<fk column>_<referenced table>_id`.
/// * `OneToOne` produces a [`Constraint::OneToOne`] named
///   `oo_<source table>_<fk column>_<referenced table>_id`.
/// * `OneToMany` and `ManyToMany` produce no constraint and return `None`.
///
/// The foreign-key column is the relation's explicit `foreign_key`, or `<relation name>_id`
/// when that is missing or empty. The referenced column is always `id`. The referenced
/// table is resolved through [`resolve_table_name_for_model`] using `model_infos`.
///
/// `on_delete` / `on_update` are read from the attributes of the backing field when
/// `fields` is supplied and the field can be found; otherwise both default to
/// [`ForeignKeyAction::Cascade`].
pub fn relation_info_to_constraint(
    relation_info: &RelationInfo,
    source_table_name: &str,
    model_infos: Option<&[ModelSchemaInfo]>,
    fields: Option<&[FieldInfo]>,
) -> Option<Constraint> {
    // Decide the constraint flavour up front; the remaining relation kinds yield nothing.
    let is_one_to_one = match relation_info.relationship_type {
        RelationshipType::ManyToOne => false,
        RelationshipType::OneToOne => true,
        RelationshipType::OneToMany | RelationshipType::ManyToMany => return None,
    };

    let (on_delete, on_update) = fields
        .and_then(|candidates| find_field_info_for_relation(relation_info, candidates))
        .map(|field_info| extract_fk_actions(&field_info.attributes))
        .unwrap_or((ForeignKeyAction::Cascade, ForeignKeyAction::Cascade));

    let referenced_table = resolve_table_name_for_model(&relation_info.related_model, model_infos);

    // Treat an empty explicit name as absent, matching `find_field_info_for_relation`,
    // so the constraint never targets a column with an empty name.
    let fk_column = relation_info
        .foreign_key
        .as_deref()
        .filter(|column| !column.is_empty())
        .map(str::to_owned)
        .unwrap_or_else(|| format!("{}_id", relation_info.name));

    let prefix = if is_one_to_one { "oo" } else { "fk" };
    let name = format!(
        "{}_{}_{}_{}_id",
        prefix, source_table_name, fk_column, referenced_table
    );

    let constraint = if is_one_to_one {
        Constraint::OneToOne {
            name,
            column: fk_column,
            referenced_table,
            referenced_column: "id".to_string(),
            on_delete,
            on_update,
            deferrable: None,
        }
    } else {
        Constraint::ForeignKey {
            name,
            columns: vec![fk_column],
            referenced_table,
            referenced_columns: vec!["id".to_string()],
            on_delete,
            on_update,
            deferrable: None,
        }
    };
    Some(constraint)
}
/// Orders models so that every model appears after the models it depends on.
///
/// `models` is a list of `(model name, dependency names)` pairs. The result is a
/// topological order computed with Kahn's algorithm:
///
/// * Dependencies that are not part of `models` are ignored.
/// * A model that depends on itself is not treated as a cycle, because a table can
///   reference itself within its own `CREATE TABLE` statement.
/// * A name listed more than once is emitted once.
/// * The order is deterministic: models that are ready at the same time keep their
///   input order.
///
/// # Errors
///
/// Returns [`SchemaError::CircularDependency`] when a cycle between distinct models
/// prevents ordering. The payload is the comma-separated list of models that could not
/// be placed, in input order.
pub fn resolve_model_order(models: &[(String, Vec<String>)]) -> Result<Vec<String>, SchemaError> {
    use std::collections::{HashSet, VecDeque};

    // Distinct model names in first-seen order; this drives every deterministic choice below.
    let mut unique_names: Vec<&str> = Vec::with_capacity(models.len());
    let mut in_degree: HashMap<&str, usize> = HashMap::with_capacity(models.len());
    let mut dependents: HashMap<&str, Vec<&str>> = HashMap::with_capacity(models.len());
    for (name, _) in models {
        if in_degree.insert(name.as_str(), 0).is_none() {
            unique_names.push(name.as_str());
            dependents.insert(name.as_str(), Vec::new());
        }
    }

    for (name, deps) in models {
        for dep in deps {
            // A self-reference does not constrain the creation order.
            if dep == name {
                continue;
            }
            // Dependencies outside the input set have no entry and are skipped.
            let Some(children) = dependents.get_mut(dep.as_str()) else {
                continue;
            };
            children.push(name.as_str());
            if let Some(degree) = in_degree.get_mut(name.as_str()) {
                *degree += 1;
            }
        }
    }

    // Seed with the models that have no in-set dependencies, in input order.
    let mut queue: VecDeque<&str> = unique_names
        .iter()
        .copied()
        .filter(|&name| in_degree.get(name) == Some(&0))
        .collect();

    let mut sorted: Vec<String> = Vec::with_capacity(unique_names.len());
    while let Some(node) = queue.pop_front() {
        sorted.push(node.to_owned());
        let Some(children) = dependents.get(node) else {
            continue;
        };
        for &child in children {
            if let Some(degree) = in_degree.get_mut(child) {
                *degree -= 1;
                if *degree == 0 {
                    queue.push_back(child);
                }
            }
        }
    }

    // Anything left unplaced is part of a cycle or depends on one.
    if sorted.len() != unique_names.len() {
        let placed: HashSet<&str> = sorted.iter().map(String::as_str).collect();
        let unplaced: Vec<&str> = unique_names
            .iter()
            .copied()
            .filter(|&name| !placed.contains(name))
            .collect();
        // The error's Display prefix already says "Circular dependency detected".
        return Err(SchemaError::CircularDependency(unplaced.join(", ")));
    }

    Ok(sorted)
}
/// Schema-relevant metadata for a single model, used to plan table creation.
pub struct ModelSchemaInfo {
/// Model name; compared against a relation's `related_model` when resolving table names.
pub name: String,
/// Name of the database table backing the model.
pub table_name: String,
/// Application label the model reports.
pub app_label: String,
/// Field metadata from which column definitions are built.
pub fields: Vec<FieldInfo>,
/// Relationship metadata from which constraints and dependencies are built.
pub relationships: Vec<RelationInfo>,
}
impl ModelSchemaInfo {
    /// Captures the schema metadata of model `M`.
    ///
    /// The model name is the last path segment of `std::any::type_name::<M>()` with any
    /// generic arguments removed; the remaining fields come straight from the [`Model`]
    /// trait's `table_name`, `app_label`, `field_metadata` and `relationship_metadata`.
    pub fn from_model<M: Model>() -> Self {
        Self {
            name: Self::short_type_name(std::any::type_name::<M>()).to_string(),
            table_name: M::table_name().to_string(),
            app_label: M::app_label().to_string(),
            fields: M::field_metadata(),
            relationships: M::relationship_metadata(),
        }
    }

    /// Returns the names of the models this model depends on, as computed by
    /// [`extract_model_dependencies`] from its relationships.
    pub fn dependencies(&self) -> Vec<String> {
        extract_model_dependencies(&self.relationships)
    }

    /// Reduces a full type path such as `app::models::Post` to its last segment (`Post`).
    ///
    /// Generic arguments are dropped before splitting, so `app::Wrapper<app::Post>`
    /// yields `Wrapper` rather than the trailing fragment `Post>`. Falls back to
    /// `"Unknown"` when no usable segment remains.
    // NOTE(review): the standard library documents `type_name` output as best-effort, so
    // this is a heuristic rather than a guaranteed identifier.
    fn short_type_name(full_name: &str) -> &str {
        let without_generics = full_name.split('<').next().unwrap_or(full_name);
        without_generics
            .rsplit("::")
            .next()
            .filter(|segment| !segment.is_empty())
            .unwrap_or("Unknown")
    }
}
/// Builds the `CreateTable` operation for model `M` without any cross-model context.
///
/// Delegates to [`create_table_operation_from_model_with_context`] with `None`, so the
/// tables referenced by relations are named by the `to_snake_case` fallback.
///
/// # Errors
///
/// Propagates any [`SchemaError`] returned by the delegated call.
pub fn create_table_operation_from_model<M: Model>() -> Result<Operation, SchemaError> {
create_table_operation_from_model_with_context::<M>(None)
}
pub fn create_table_operation_from_model_with_context<M: Model>(
model_infos: Option<&[ModelSchemaInfo]>,
) -> Result<Operation, SchemaError> {
let table_name = M::table_name().to_string();
let field_metadata = M::field_metadata();
let columns: Vec<ColumnDefinition> = field_metadata
.iter()
.map(field_info_to_column_definition)
.collect::<Result<Vec<_>, _>>()?;
let constraints: Vec<Constraint> = M::relationship_metadata()
.iter()
.filter_map(|rel| {
relation_info_to_constraint(rel, &table_name, model_infos, Some(&field_metadata))
})
.collect();
Ok(Operation::CreateTable {
name: table_name,
columns,
constraints,
without_rowid: None,
interleave_in_parent: None,
partition: None,
})
}
/// Wraps the `CreateTable` operation for model `M` in a single-operation migration.
///
/// The migration is named `migration_name`, uses `M::app_label()` as its app label, is
/// atomic, is marked as initial, and declares no dependencies or replacements.
///
/// # Errors
///
/// Propagates any [`SchemaError`] from [`create_table_operation_from_model`].
pub fn create_migration_from_model<M: Model>(
    migration_name: &str,
) -> Result<Migration, SchemaError> {
    let create_table = create_table_operation_from_model::<M>()?;
    let migration = Migration {
        name: migration_name.to_owned(),
        app_label: M::app_label().to_string(),
        operations: vec![create_table],
        dependencies: Vec::new(),
        replaces: Vec::new(),
        atomic: true,
        initial: Some(true),
        state_only: false,
        database_only: false,
        optional_dependencies: Vec::new(),
        swappable_dependencies: Vec::new(),
    };
    Ok(migration)
}
pub fn create_table_operations_from_models(
model_infos: Vec<ModelSchemaInfo>,
) -> Result<Vec<Operation>, SchemaError> {
let models_with_deps: Vec<(String, Vec<String>)> = model_infos
.iter()
.map(|info| (info.name.clone(), info.dependencies()))
.collect();
let ordered_names = resolve_model_order(&models_with_deps)?;
let name_to_info: HashMap<String, &ModelSchemaInfo> = model_infos
.iter()
.map(|info| (info.name.clone(), info))
.collect();
let mut operations = Vec::new();
for name in ordered_names {
if let Some(info) = name_to_info.get(&name) {
let columns: Vec<ColumnDefinition> = info
.fields
.iter()
.map(field_info_to_column_definition)
.collect::<Result<Vec<_>, _>>()?;
let constraints: Vec<Constraint> = info
.relationships
.iter()
.filter_map(|rel| {
relation_info_to_constraint(
rel,
&info.table_name,
Some(&model_infos),
Some(&info.fields),
)
})
.collect();
operations.push(Operation::CreateTable {
name: info.table_name.clone(),
columns,
constraints,
without_rowid: None,
interleave_in_parent: None,
partition: None,
});
}
}
Ok(operations)
}
pub async fn create_table_for_model<M: Model>(
connection: &DatabaseConnection,
) -> Result<(), SchemaError> {
let migration = create_migration_from_model::<M>("0001_auto_create")?;
let mut executor = DatabaseMigrationExecutor::new(connection.clone());
executor
.apply_migrations(&[migration])
.await
.map_err(|e| SchemaError::MigrationExecution(e.to_string()))?;
Ok(())
}
pub async fn create_tables_for_models(
connection: &DatabaseConnection,
model_infos: Vec<ModelSchemaInfo>,
) -> Result<(), SchemaError> {
let operations = create_table_operations_from_models(model_infos)?;
if operations.is_empty() {
return Ok(());
}
let migration = Migration {
name: "0001_auto_batch_create".to_string(),
app_label: "test".to_string(),
operations,
dependencies: vec![],
replaces: vec![],
atomic: true,
initial: Some(true),
state_only: false,
database_only: false,
optional_dependencies: vec![],
swappable_dependencies: vec![],
};
let mut executor = DatabaseMigrationExecutor::new(connection.clone());
executor
.apply_migrations(&[migration])
.await
.map_err(|e| SchemaError::MigrationExecution(e.to_string()))?;
Ok(())
}
// Unit tests for the pure helpers of this module. The async, database-backed entry
// points (`create_table_for_model`, `create_tables_for_models`) are not exercised here.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// Dependencies must be ordered before the models that reference them.
#[rstest]
fn test_resolve_model_order_simple() {
let models = vec![
("User".to_string(), vec![]),
("Post".to_string(), vec!["User".to_string()]),
(
"Comment".to_string(),
vec!["Post".to_string(), "User".to_string()],
),
];
let order = resolve_model_order(&models).unwrap();
let user_idx = order.iter().position(|n| n == "User").unwrap();
let post_idx = order.iter().position(|n| n == "Post").unwrap();
let comment_idx = order.iter().position(|n| n == "Comment").unwrap();
assert!(user_idx < post_idx);
assert!(user_idx < comment_idx);
assert!(post_idx < comment_idx);
}
// Two models that depend on each other have no valid order and must yield an error.
#[rstest]
fn test_resolve_model_order_circular() {
let models = vec![
("A".to_string(), vec!["B".to_string()]),
("B".to_string(), vec!["A".to_string()]),
];
let result = resolve_model_order(&models);
assert!(result.is_err());
}
// A dependency that is not part of the input set is ignored rather than reported.
#[rstest]
fn test_resolve_model_order_external_deps() {
let models = vec![
("User".to_string(), vec!["ExternalModel".to_string()]),
("Post".to_string(), vec!["User".to_string()]),
];
let order = resolve_model_order(&models).unwrap();
assert_eq!(order.len(), 2);
}
// Only the ManyToOne relation contributes a dependency; the ManyToMany one is skipped.
#[rstest]
fn test_extract_model_dependencies() {
use reinhardt_db::orm::inspection::RelationInfo;
let relations = vec![
RelationInfo::new(
"author".to_string(),
RelationshipType::ManyToOne,
"User".to_string(),
),
RelationInfo::new(
"tags".to_string(),
RelationshipType::ManyToMany,
"Tag".to_string(),
),
];
let deps = extract_model_dependencies(&relations);
assert_eq!(deps.len(), 1);
assert!(deps.contains(&"User".to_string()));
}
// A matching entry in `model_infos` supplies the table name.
#[rstest]
fn test_resolve_table_name_for_model_with_model_infos() {
let model_infos = vec![ModelSchemaInfo {
name: "User".to_string(),
table_name: "custom_users".to_string(),
app_label: "test".to_string(),
fields: vec![],
relationships: vec![],
}];
let table_name = resolve_table_name_for_model("User", Some(&model_infos));
assert_eq!(table_name, "custom_users");
}
// Without `model_infos` the name is derived in snake_case.
#[rstest]
fn test_resolve_table_name_for_model_fallback_to_snake_case() {
let table_name = resolve_table_name_for_model("BlogPost", None);
assert_eq!(table_name, "blog_post");
}
// A model missing from `model_infos` also falls back to the snake_case name.
#[rstest]
fn test_resolve_table_name_for_model_not_found_in_infos() {
let model_infos = vec![ModelSchemaInfo {
name: "User".to_string(),
table_name: "users".to_string(),
app_label: "test".to_string(),
fields: vec![],
relationships: vec![],
}];
let table_name = resolve_table_name_for_model("BlogPost", Some(&model_infos));
assert_eq!(table_name, "blog_post");
}
// ManyToOne with an explicit FK column: checks name, columns, target and default actions.
#[rstest]
fn test_relation_info_to_constraint_many_to_one() {
let relation = RelationInfo::new("author", RelationshipType::ManyToOne, "User")
.with_foreign_key("author_id");
let constraint = relation_info_to_constraint(&relation, "posts", None, None);
assert!(constraint.is_some());
match constraint.unwrap() {
Constraint::ForeignKey {
name,
columns,
referenced_table,
referenced_columns,
on_delete,
on_update,
..
} => {
assert_eq!(name, "fk_posts_author_id_user_id");
assert_eq!(columns, vec!["author_id".to_string()]);
assert_eq!(referenced_table, "user");
assert_eq!(referenced_columns, vec!["id".to_string()]);
assert!(matches!(on_delete, ForeignKeyAction::Cascade));
assert!(matches!(on_update, ForeignKeyAction::Cascade));
}
_ => panic!("Expected ForeignKey constraint"),
}
}
// Without an explicit FK column the column name is derived as `<relation name>_id`.
#[rstest]
fn test_relation_info_to_constraint_many_to_one_without_explicit_fk() {
let relation = RelationInfo::new("author", RelationshipType::ManyToOne, "User");
let constraint = relation_info_to_constraint(&relation, "posts", None, None);
assert!(constraint.is_some());
match constraint.unwrap() {
Constraint::ForeignKey { columns, .. } => {
assert_eq!(columns, vec!["author_id".to_string()]);
}
_ => panic!("Expected ForeignKey constraint"),
}
}
// OneToOne relations produce a `OneToOne` constraint with the `oo_` name prefix.
#[rstest]
fn test_relation_info_to_constraint_one_to_one() {
let relation = RelationInfo::new("profile", RelationshipType::OneToOne, "UserProfile")
.with_foreign_key("profile_id");
let constraint = relation_info_to_constraint(&relation, "users", None, None);
assert!(constraint.is_some());
match constraint.unwrap() {
Constraint::OneToOne {
name,
column,
referenced_table,
referenced_column,
on_delete,
on_update,
..
} => {
assert_eq!(name, "oo_users_profile_id_user_profile_id");
assert_eq!(column, "profile_id");
assert_eq!(referenced_table, "user_profile");
assert_eq!(referenced_column, "id");
assert!(matches!(on_delete, ForeignKeyAction::Cascade));
assert!(matches!(on_update, ForeignKeyAction::Cascade));
}
_ => panic!("Expected OneToOne constraint"),
}
}
// OneToMany relations produce no constraint.
#[rstest]
fn test_relation_info_to_constraint_one_to_many_returns_none() {
let relation = RelationInfo::new("posts", RelationshipType::OneToMany, "Post");
let constraint = relation_info_to_constraint(&relation, "users", None, None);
assert!(constraint.is_none());
}
// ManyToMany relations produce no constraint.
#[rstest]
fn test_relation_info_to_constraint_many_to_many_returns_none() {
let relation = RelationInfo::new("tags", RelationshipType::ManyToMany, "Tag");
let constraint = relation_info_to_constraint(&relation, "posts", None, None);
assert!(constraint.is_none());
}
// The referenced table name is taken from `model_infos` when the related model is listed.
#[rstest]
fn test_relation_info_to_constraint_with_model_infos() {
let model_infos = vec![ModelSchemaInfo {
name: "User".to_string(),
table_name: "app_users".to_string(),
app_label: "test".to_string(),
fields: vec![],
relationships: vec![],
}];
let relation = RelationInfo::new("author", RelationshipType::ManyToOne, "User")
.with_foreign_key("author_id");
let constraint = relation_info_to_constraint(&relation, "posts", Some(&model_infos), None);
assert!(constraint.is_some());
match constraint.unwrap() {
Constraint::ForeignKey {
referenced_table, ..
} => {
assert_eq!(referenced_table, "app_users");
}
_ => panic!("Expected ForeignKey constraint"),
}
}
// CASCADE is recognized regardless of letter case.
#[rstest]
fn test_parse_fk_action_cascade() {
assert!(matches!(
parse_fk_action("CASCADE"),
ForeignKeyAction::Cascade
));
assert!(matches!(
parse_fk_action("cascade"),
ForeignKeyAction::Cascade
));
assert!(matches!(
parse_fk_action("Cascade"),
ForeignKeyAction::Cascade
));
}
// SET NULL is accepted with a space, without a separator, and with an underscore.
#[rstest]
fn test_parse_fk_action_set_null() {
assert!(matches!(
parse_fk_action("SET NULL"),
ForeignKeyAction::SetNull
));
assert!(matches!(
parse_fk_action("SETNULL"),
ForeignKeyAction::SetNull
));
assert!(matches!(
parse_fk_action("SET_NULL"),
ForeignKeyAction::SetNull
));
}
// RESTRICT maps to the Restrict action.
#[rstest]
fn test_parse_fk_action_restrict() {
assert!(matches!(
parse_fk_action("RESTRICT"),
ForeignKeyAction::Restrict
));
}
// NO ACTION is accepted in all three separator spellings.
#[rstest]
fn test_parse_fk_action_no_action() {
assert!(matches!(
parse_fk_action("NO ACTION"),
ForeignKeyAction::NoAction
));
assert!(matches!(
parse_fk_action("NOACTION"),
ForeignKeyAction::NoAction
));
assert!(matches!(
parse_fk_action("NO_ACTION"),
ForeignKeyAction::NoAction
));
}
// SET DEFAULT is accepted in all three separator spellings.
#[rstest]
fn test_parse_fk_action_set_default() {
assert!(matches!(
parse_fk_action("SET DEFAULT"),
ForeignKeyAction::SetDefault
));
assert!(matches!(
parse_fk_action("SETDEFAULT"),
ForeignKeyAction::SetDefault
));
assert!(matches!(
parse_fk_action("SET_DEFAULT"),
ForeignKeyAction::SetDefault
));
}
// Unrecognized and empty input both fall back to Cascade.
#[rstest]
fn test_parse_fk_action_unknown_defaults_to_cascade() {
assert!(matches!(
parse_fk_action("UNKNOWN"),
ForeignKeyAction::Cascade
));
assert!(matches!(parse_fk_action(""), ForeignKeyAction::Cascade));
}
// Both actions are read from string attributes when present.
#[rstest]
fn test_extract_fk_actions_with_both_actions() {
let mut attrs = HashMap::new();
attrs.insert(
"on_delete".to_string(),
FieldKwarg::String("SET NULL".to_string()),
);
attrs.insert(
"on_update".to_string(),
FieldKwarg::String("RESTRICT".to_string()),
);
let (on_delete, on_update) = extract_fk_actions(&attrs);
assert!(matches!(on_delete, ForeignKeyAction::SetNull));
assert!(matches!(on_update, ForeignKeyAction::Restrict));
}
// A missing `on_update` attribute defaults to Cascade while `on_delete` is honored.
#[rstest]
fn test_extract_fk_actions_with_only_on_delete() {
let mut attrs = HashMap::new();
attrs.insert(
"on_delete".to_string(),
FieldKwarg::String("RESTRICT".to_string()),
);
let (on_delete, on_update) = extract_fk_actions(&attrs);
assert!(matches!(on_delete, ForeignKeyAction::Restrict));
assert!(matches!(on_update, ForeignKeyAction::Cascade));
}
// With no attributes at all, both actions default to Cascade.
#[rstest]
fn test_extract_fk_actions_empty_attrs_defaults_to_cascade() {
let attrs = HashMap::new();
let (on_delete, on_update) = extract_fk_actions(&attrs);
assert!(matches!(on_delete, ForeignKeyAction::Cascade));
assert!(matches!(on_update, ForeignKeyAction::Cascade));
}
// Model names are converted to snake_case table names.
#[rstest]
fn test_infer_table_name() {
assert_eq!(infer_table_name("BlogPost"), "blog_post");
assert_eq!(infer_table_name("UserProfile"), "user_profile");
assert_eq!(infer_table_name("User"), "user");
}
// Actions declared on the backing field's attributes override the Cascade defaults.
#[rstest]
fn test_relation_info_to_constraint_with_field_attrs() {
let mut attrs = HashMap::new();
attrs.insert(
"on_delete".to_string(),
FieldKwarg::String("SET NULL".to_string()),
);
attrs.insert(
"on_update".to_string(),
FieldKwarg::String("NO ACTION".to_string()),
);
// The field name matches the relation's explicit foreign-key column below.
let field_info = FieldInfo {
name: "author_id".to_string(),
field_type: "BigInteger".to_string(),
nullable: true,
primary_key: false,
unique: false,
blank: false,
editable: true,
default: None,
db_default: None,
db_column: None,
choices: None,
attributes: attrs,
};
let relation = RelationInfo::new("author", RelationshipType::ManyToOne, "User")
.with_foreign_key("author_id");
let constraint = relation_info_to_constraint(&relation, "posts", None, Some(&[field_info]));
assert!(constraint.is_some());
match constraint.unwrap() {
Constraint::ForeignKey {
on_delete,
on_update,
..
} => {
assert!(matches!(on_delete, ForeignKeyAction::SetNull));
assert!(matches!(on_update, ForeignKeyAction::NoAction));
}
_ => panic!("Expected ForeignKey constraint"),
}
}
}