use std::collections::{HashMap, HashSet};
use rustrails_macros::{BelongsToAssociation, HasManyAssociation, HasOneAssociation};
use rustrails_support::{
database,
inflector::{foreign_key, singularize},
runtime,
};
use sea_orm::{ColumnTrait, EntityTrait, Iterable};
use serde::Serialize;
use serde_json::{Value, json};
use crate::{Querying, Record, RecordError, Relation};
pub mod belongs_to;
pub mod has_and_belongs_to_many;
pub mod has_many;
pub mod has_one;
pub use belongs_to::BelongsToBuilder;
pub use has_and_belongs_to_many::HasAndBelongsToManyBuilder;
pub use has_many::HasManyBuilder;
pub use has_one::HasOneBuilder;
/// The kind of relationship an [`AssociationMeta`] entry describes.
///
/// Derives `Hash` alongside `Eq` so kinds can be used as `HashMap`/`HashSet`
/// keys, e.g. when grouping a registry's entries by kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AssociationType {
    /// One owner record, many target records (`has_many`).
    HasMany,
    /// One owner record, at most one target record (`has_one`).
    HasOne,
    /// The declaring record points at its owner (`belongs_to`).
    BelongsTo,
    /// Many-to-many relationship (`has_and_belongs_to_many`).
    HasAndBelongsToMany,
}

/// Cleanup policy attached to an association.
///
/// NOTE(review): the variants mirror Rails' `dependent:` options
/// (`:destroy`, `:delete`, `:nullify`, `:restrict`). The code that acts on
/// them is not in this module, so confirm the exact semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DependentAction {
    Destroy,
    Delete,
    Nullify,
    Restrict,
}

/// Metadata describing one named association, stored in an [`AssociationRegistry`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AssociationMeta {
    /// Association name; the key matched by [`AssociationRegistry::get`].
    pub name: String,
    /// Relationship kind; the value matched by [`AssociationRegistry::of_type`].
    pub association_type: AssociationType,
    /// Name of the table on the other side of the association.
    pub target_table: String,
    /// Foreign-key column linking the two sides.
    pub foreign_key: String,
    /// Primary-key column the foreign key refers to.
    pub primary_key: String,
    /// Declared cleanup policy, or `None` when no policy was declared.
    pub dependent: Option<DependentAction>,
    /// Intermediate table/association name when the association goes through one.
    pub through: Option<String>,
    /// Whether the association was declared polymorphic.
    pub polymorphic: bool,
}

/// An ordered collection of [`AssociationMeta`] entries for one record type.
///
/// Entries keep insertion order and duplicate names are allowed. Name lookups
/// return the earliest entry added under that name.
#[derive(Debug, Default)]
pub struct AssociationRegistry {
    associations: Vec<AssociationMeta>,
}

impl AssociationRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `meta` after every previously added entry; duplicates are kept.
    pub fn add(&mut self, meta: AssociationMeta) {
        self.associations.push(meta);
    }

    /// Returns the first entry whose name equals `name` exactly (case-sensitive).
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&AssociationMeta> {
        self.associations.iter().find(|meta| meta.name == name)
    }

    /// Returns every entry of kind `assoc_type`, in insertion order.
    #[must_use]
    pub fn of_type(&self, assoc_type: AssociationType) -> Vec<&AssociationMeta> {
        self.associations
            .iter()
            .filter(|meta| meta.association_type == assoc_type)
            .collect()
    }

    /// Returns all entries in insertion order.
    #[must_use]
    pub fn all(&self) -> &[AssociationMeta] {
        &self.associations
    }
}
/// Exposes the association metadata registered for a record type.
///
/// The registry is returned with a `'static` lifetime. Implementations
/// presumably hand out one shared instance (the test impl in this module is
/// backed by a `LazyLock`); confirm for other implementors.
pub trait HasAssociations: Record {
    /// Returns the registry listing this record type's associations.
    fn associations() -> &'static AssociationRegistry;
}
pub trait HasManyQuery<Target> {
    /// Loads every `Target` row whose foreign-key column holds this record's id.
    ///
    /// The column is the `foreign_key` override from the `has_many!` declaration
    /// when one exists. Otherwise it is derived from this record's own table name
    /// via `default_owner_foreign_key`.
    ///
    /// # Errors
    ///
    /// Propagates any [`RecordError`] raised by the query. A record without an id
    /// yields an empty vector without querying.
    fn has_many(&self) -> Result<Vec<Target>, RecordError>
    where
        Self: Record + Serialize + HasManyAssociation<Target>,
        Target: Querying,
        <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        self.id().map_or_else(
            || Ok(Vec::new()),
            |owner_id| {
                let declared =
                    <Self as HasManyAssociation<Target>>::association_definition().foreign_key;
                let column = match declared {
                    Some(name) => name.to_owned(),
                    None => default_owner_foreign_key::<Self>(),
                };
                load_many_by_field::<Target>(&column, owner_id)
            },
        )
    }
}
pub trait HasManyThroughQuery<Target, Join> {
    /// Loads the distinct `Target` records reachable through `Join` rows that
    /// reference this record.
    ///
    /// First the `Join` rows are loaded, filtered on the owner's default foreign-key
    /// column. Then each row's default target foreign-key value is read from its
    /// serialized form. Rows whose key is null are skipped. Each distinct id is
    /// fetched once with its own `find_sync` call. Results keep the order in which
    /// their ids first appear among the join rows.
    ///
    /// Both join columns always use the defaults; this method does not read
    /// anything from the `has_many!` definition.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while loading the join rows, reading a key
    /// from a join row, or fetching a target. A record without an id yields an
    /// empty vector without querying.
    fn has_many(&self) -> Result<Vec<Target>, RecordError>
    where
        Self: Record + Serialize + HasManyAssociation<Target>,
        Target: Querying,
        Join: Querying + Serialize,
        <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
        <Join::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
    {
        let Some(owner_id) = self.id() else {
            return Ok(Vec::new());
        };
        let owner_column = default_owner_foreign_key::<Self>();
        let target_column = default_target_foreign_key::<Target>();
        let mut already_loaded = HashSet::new();
        load_many_by_field::<Join>(&owner_column, owner_id)?
            .into_iter()
            // `transpose` turns a null key (`Ok(None)`) into `None`, which drops the
            // row, while extraction errors stay in the stream to stop `collect`.
            .filter_map(|row| extract_serialized_id(&row, &target_column).transpose())
            .filter(|candidate| match candidate {
                Ok(target_id) => already_loaded.insert(*target_id),
                Err(_) => true,
            })
            .map(|candidate| candidate.and_then(|target_id| Target::find_sync(target_id)))
            .collect()
    }
}
pub trait BelongsToQuery<Target> {
fn belongs_to(&self) -> Result<Target, RecordError>
where
Self: Record + Serialize + BelongsToAssociation<Target>,
Target: Querying,
<Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
let definition = <Self as BelongsToAssociation<Target>>::association_definition();
let foreign_key = definition
.foreign_key
.map(str::to_owned)
.unwrap_or_else(|| default_target_foreign_key::<Target>());
let target_id = extract_serialized_id(self, &foreign_key)?.ok_or(RecordError::NotFound)?;
Target::find_sync(target_id)
}
}
pub trait HasOneQuery<Target> {
fn has_one(&self) -> Result<Option<Target>, RecordError>
where
Self: Record + Serialize + HasOneAssociation<Target>,
Target: Querying,
<Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
let Some(owner_id) = self.id() else {
return Ok(None);
};
let definition = <Self as HasOneAssociation<Target>>::association_definition();
let foreign_key = definition
.foreign_key
.map(str::to_owned)
.unwrap_or_else(|| default_owner_foreign_key::<Self>());
load_one_by_field::<Target>(&foreign_key, owner_id)
}
}
/// Default foreign-key column that refers to `Model`.
///
/// `Model`'s table name is singularized and then passed to `inflector::foreign_key`
/// (presumably yielding `<singular>_id`).
fn default_owner_foreign_key<Model: Record>() -> String {
    let singular_table = singularize(Model::table_name());
    foreign_key(&singular_table)
}
/// Default foreign-key column a record uses to refer to `Target`.
///
/// `Target`'s table name is singularized and then passed to `inflector::foreign_key`.
fn default_target_foreign_key<Target>() -> String
where
    Target: Record,
{
    let table = Target::table_name();
    foreign_key(&singularize(table))
}
/// Loads every `Target` row whose `field` column matches `value`.
///
/// Builds a `Relation` filtered on `{field: value}` and drives the async `load`
/// to completion on the current thread with `runtime::block_on`, using the
/// connection supplied by `database::with_db`.
///
/// # Errors
///
/// Returns whatever [`RecordError`] the query produces.
fn load_many_by_field<Target>(field: &str, value: i64) -> Result<Vec<Target>, RecordError>
where
    Target: Querying,
    <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
    database::with_db(|db| {
        let conditions = HashMap::from([(field.to_owned(), json!(value))]);
        runtime::block_on(Relation::<Target>::new().r#where(conditions).load(db))
    })
}
/// Loads the first `Target` row whose `field` column matches `value`, if any.
///
/// Builds a `Relation` filtered on `{field: value}` and drives the async `first`
/// to completion on the current thread with `runtime::block_on`, using the
/// connection supplied by `database::with_db`.
///
/// # Errors
///
/// Returns whatever [`RecordError`] the query produces.
fn load_one_by_field<Target>(field: &str, value: i64) -> Result<Option<Target>, RecordError>
where
    Target: Querying,
    <Target::Entity as EntityTrait>::Column: ColumnTrait + Iterable,
{
    database::with_db(|db| {
        let criteria = HashMap::from([(field.to_owned(), json!(value))]);
        runtime::block_on(Relation::<Target>::new().r#where(criteria).first(db))
    })
}
fn extract_serialized_id<T: Serialize>(
record: &T,
field: &str,
) -> Result<Option<i64>, RecordError> {
let value =
serde_json::to_value(record).map_err(|error| RecordError::Invalid(error.to_string()))?;
let object = value.as_object().ok_or_else(|| {
RecordError::Invalid("associated record must serialize to a JSON object".to_owned())
})?;
match object.get(field) {
Some(Value::Null) => Ok(None),
Some(value) => Ok(Some(json_value_to_i64(value, field)?)),
None => Err(RecordError::Invalid(format!(
"missing association key `{field}` on serialized record"
))),
}
}
/// Converts a JSON association-key value into an `i64`.
///
/// # Errors
///
/// Returns [`RecordError::Invalid`] in these cases:
/// - `value` is not a JSON number;
/// - it is a non-integer number (e.g. a float);
/// - it is an unsigned integer above `i64::MAX`.
fn json_value_to_i64(value: &Value, field: &str) -> Result<i64, RecordError> {
    let Value::Number(number) = value else {
        return Err(RecordError::Invalid(format!(
            "association key `{field}` must be numeric"
        )));
    };
    if let Some(signed) = number.as_i64() {
        return Ok(signed);
    }
    // Reaching here with a `u64` means the value exceeds `i64::MAX`, so the
    // conversion below reports the overflow.
    match number.as_u64() {
        Some(unsigned) => i64::try_from(unsigned).map_err(|_| {
            RecordError::Invalid(format!("association key `{field}` does not fit in i64"))
        }),
        None => Err(RecordError::Invalid(format!(
            "association key `{field}` must be an integer"
        ))),
    }
}
// Tests for `AssociationRegistry`, for the definitions generated by the
// `has_many!` / `belongs_to!` / `has_one!` macros, and one end-to-end test that
// runs the query traits against an in-memory SQLite database.
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::LazyLock};
    use rustrails_macros::{
        AssociationKind, BelongsToAssociation, HasManyAssociation, HasOneAssociation, belongs_to,
        has_many, has_one, model,
    };
    use rustrails_support::{database, runtime};
    use serde_json::json;
    use crate::{
        Persistence,
        associations::{
            AssociationRegistry, AssociationType, BelongsToBuilder, BelongsToQuery,
            DependentAction, HasAndBelongsToManyBuilder, HasAssociations, HasManyBuilder,
            HasManyQuery, HasOneBuilder, HasOneQuery,
        },
        base::test_support::TestUser,
    };
    // Database-backed models used by `association_query_traits_load_related_records`;
    // their tables are created by that test.
    model! {
        QueryBlog {
            title: String,
        }
        table_name: "query_blogs";
    }
    model! {
        QueryPost {
            query_blog_id: i64,
            title: String,
        }
        table_name: "query_posts";
    }
    model! {
        QueryProfile {
            query_blog_id: i64,
            bio: String,
        }
        table_name: "query_profiles";
    }
    // `has_one!` declares no foreign key, so it exercises the default owner key.
    // The `query_profiles.query_blog_id` column in the end-to-end test must match it.
    has_many!(QueryBlog => QueryPost, foreign_key: query_blog_id);
    belongs_to!(QueryPost => QueryBlog, foreign_key: query_blog_id);
    has_one!(QueryBlog => QueryProfile);
    // Marker types used only to inspect macro-generated association
    // definitions; they are never persisted or queried.
    struct DefaultHasManyAuthor;
    struct DefaultHasManyPost;
    struct ForeignKeyHasManyAuthor;
    struct ForeignKeyHasManyPost;
    struct ThroughHasManyAuthor;
    struct ThroughHasManyTag;
    struct PostTagging;
    struct DefaultBelongsToComment;
    struct DefaultBelongsToPost;
    struct ForeignKeyBelongsToComment;
    struct ForeignKeyBelongsToBlog;
    struct DefaultHasOneUser;
    struct DefaultHasOneProfile;
    has_many!(DefaultHasManyAuthor => DefaultHasManyPost);
    has_many!(ForeignKeyHasManyAuthor => ForeignKeyHasManyPost, foreign_key: author_id);
    has_many!(ThroughHasManyAuthor => ThroughHasManyTag, through: PostTagging);
    belongs_to!(DefaultBelongsToComment => DefaultBelongsToPost);
    belongs_to!(ForeignKeyBelongsToComment => ForeignKeyBelongsToBlog, foreign_key: blog_id);
    has_one!(DefaultHasOneUser => DefaultHasOneProfile);
    // Registry fixture returned by `TestUser::associations()`. The insertion order
    // (comments, profile, account, roles) is asserted by
    // `registry_exposes_all_associations_in_order`.
    static TEST_ASSOCIATIONS: LazyLock<AssociationRegistry> = LazyLock::new(|| {
        let mut registry = AssociationRegistry::new();
        registry.add(
            HasManyBuilder::new("comments")
                .dependent(DependentAction::Destroy)
                .build(),
        );
        registry.add(HasOneBuilder::new("profile").build());
        registry.add(BelongsToBuilder::new("account").build());
        registry.add(
            HasAndBelongsToManyBuilder::new("roles")
                .through("accounts_roles")
                .build(),
        );
        registry
    });
    impl HasAssociations for TestUser {
        fn associations() -> &'static AssociationRegistry {
            &TEST_ASSOCIATIONS
        }
    }
    #[test]
    fn registry_returns_named_association() {
        let association = TestUser::associations()
            .get("comments")
            .expect("comments association should exist");
        assert_eq!(association.association_type, AssociationType::HasMany);
        assert_eq!(association.dependent, Some(DependentAction::Destroy));
    }
    #[test]
    fn registry_filters_associations_by_type() {
        let has_many = TestUser::associations().of_type(AssociationType::HasMany);
        let belongs_to = TestUser::associations().of_type(AssociationType::BelongsTo);
        assert_eq!(has_many.len(), 1);
        assert_eq!(has_many[0].name, "comments");
        assert_eq!(belongs_to.len(), 1);
        assert_eq!(belongs_to[0].name, "account");
    }
    #[test]
    fn registry_exposes_all_associations_in_order() {
        let names = TestUser::associations()
            .all()
            .iter()
            .map(|meta| meta.name.as_str())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["comments", "profile", "account", "roles"]);
    }
    #[test]
    fn registry_returns_none_for_unknown_association() {
        assert!(TestUser::associations().get("missing").is_none());
    }
    #[test]
    fn new_registry_starts_empty() {
        let registry = AssociationRegistry::new();
        assert!(registry.all().is_empty());
    }
    #[test]
    fn add_appends_associations() {
        let mut registry = AssociationRegistry::new();
        registry.add(HasManyBuilder::new("comments").build());
        registry.add(HasOneBuilder::new("profile").build());
        assert_eq!(registry.all().len(), 2);
    }
    #[test]
    fn get_is_case_sensitive() {
        assert!(TestUser::associations().get("Comments").is_none());
    }
    #[test]
    fn get_returns_first_matching_name_when_duplicates_exist() {
        let mut registry = AssociationRegistry::new();
        registry.add(HasManyBuilder::new("comments").build());
        registry.add(
            HasManyBuilder::new("comments")
                .foreign_key("owner_id")
                .build(),
        );
        let association = registry.get("comments").expect("association should exist");
        // The second entry overrides the key to `owner_id`, so seeing the first
        // entry's builder-default key shows `get` returned the earliest match.
        assert_eq!(association.foreign_key, "comment_id");
    }
    #[test]
    fn of_type_returns_empty_when_no_associations_match() {
        let registry = AssociationRegistry::new();
        assert!(registry.of_type(AssociationType::HasMany).is_empty());
    }
    #[test]
    fn of_type_preserves_declaration_order() {
        let mut registry = AssociationRegistry::new();
        registry.add(HasManyBuilder::new("comments").build());
        registry.add(HasManyBuilder::new("tags").build());
        let names = registry
            .of_type(AssociationType::HasMany)
            .into_iter()
            .map(|meta| meta.name.as_str())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["comments", "tags"]);
    }
    #[test]
    fn all_returns_empty_slice_for_new_registry() {
        let registry = AssociationRegistry::new();
        assert_eq!(registry.all(), &[]);
    }
    #[test]
    fn associations_registry_is_stable_across_calls() {
        // Pointer equality: both calls must hand back the same static registry.
        assert!(std::ptr::eq(
            TestUser::associations(),
            TestUser::associations()
        ));
    }
    // Exhaustive match with no wildcard arm: a new `AssociationKind` variant makes
    // this helper fail to compile until it is handled here.
    fn association_kind_name(kind: AssociationKind) -> &'static str {
        match kind {
            AssociationKind::HasMany => "has_many",
            AssociationKind::BelongsTo => "belongs_to",
            AssociationKind::HasOne => "has_one",
        }
    }
    #[test]
    fn default_has_many_definition_sets_has_many_kind() {
        let definition =
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
        assert_eq!(definition.kind, AssociationKind::HasMany);
    }
    #[test]
    fn default_has_many_definition_records_model_name() {
        let definition =
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
        assert_eq!(definition.model, "DefaultHasManyAuthor");
    }
    #[test]
    fn default_has_many_definition_records_target_name() {
        let definition =
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
        assert_eq!(definition.target, "DefaultHasManyPost");
    }
    #[test]
    fn default_has_many_definition_has_no_foreign_key_override() {
        let definition =
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
        assert_eq!(definition.foreign_key, None);
    }
    #[test]
    fn default_has_many_definition_has_no_through_target() {
        let definition =
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
        assert_eq!(definition.through, None);
    }
    #[test]
    fn foreign_key_has_many_definition_records_foreign_key_override() {
        let definition = <ForeignKeyHasManyAuthor as HasManyAssociation<ForeignKeyHasManyPost>>::association_definition();
        assert_eq!(definition.foreign_key, Some("author_id"));
    }
    #[test]
    fn through_has_many_definition_records_through_target() {
        let definition =
            <ThroughHasManyAuthor as HasManyAssociation<ThroughHasManyTag>>::association_definition(
            );
        assert_eq!(definition.through, Some("PostTagging"));
    }
    #[test]
    fn has_many_association_definition_is_stable_across_calls() {
        assert_eq!(
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition(),
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition(),
        );
    }
    #[test]
    fn default_belongs_to_definition_sets_belongs_to_kind() {
        let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
        assert_eq!(definition.kind, AssociationKind::BelongsTo);
    }
    #[test]
    fn default_belongs_to_definition_records_model_name() {
        let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
        assert_eq!(definition.model, "DefaultBelongsToComment");
    }
    #[test]
    fn default_belongs_to_definition_records_target_name() {
        let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
        assert_eq!(definition.target, "DefaultBelongsToPost");
    }
    #[test]
    fn default_belongs_to_definition_has_no_foreign_key_override() {
        let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
        assert_eq!(definition.foreign_key, None);
    }
    #[test]
    fn default_belongs_to_definition_has_no_through_target() {
        let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
        assert_eq!(definition.through, None);
    }
    #[test]
    fn foreign_key_belongs_to_definition_records_foreign_key_override() {
        let definition = <ForeignKeyBelongsToComment as BelongsToAssociation<
            ForeignKeyBelongsToBlog,
        >>::association_definition();
        assert_eq!(definition.foreign_key, Some("blog_id"));
    }
    #[test]
    fn belongs_to_association_definition_is_stable_across_calls() {
        assert_eq!(
            <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition(),
            <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition(),
        );
    }
    #[test]
    fn default_has_one_definition_sets_has_one_kind() {
        let definition =
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            );
        assert_eq!(definition.kind, AssociationKind::HasOne);
    }
    #[test]
    fn default_has_one_definition_records_model_name() {
        let definition =
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            );
        assert_eq!(definition.model, "DefaultHasOneUser");
    }
    #[test]
    fn default_has_one_definition_records_target_name() {
        let definition =
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            );
        assert_eq!(definition.target, "DefaultHasOneProfile");
    }
    #[test]
    fn default_has_one_definition_has_no_foreign_key_override() {
        let definition =
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            );
        assert_eq!(definition.foreign_key, None);
    }
    #[test]
    fn default_has_one_definition_has_no_through_target() {
        let definition =
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            );
        assert_eq!(definition.through, None);
    }
    #[test]
    fn has_one_association_definition_is_stable_across_calls() {
        assert_eq!(
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            ),
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            ),
        );
    }
    #[test]
    fn has_many_definition_kind_matches_has_many_branch() {
        let definition =
            <DefaultHasManyAuthor as HasManyAssociation<DefaultHasManyPost>>::association_definition();
        assert_eq!(association_kind_name(definition.kind), "has_many");
    }
    #[test]
    fn belongs_to_definition_kind_matches_belongs_to_branch() {
        let definition = <DefaultBelongsToComment as BelongsToAssociation<DefaultBelongsToPost>>::association_definition();
        assert_eq!(association_kind_name(definition.kind), "belongs_to");
    }
    #[test]
    fn has_one_definition_kind_matches_has_one_branch() {
        let definition =
            <DefaultHasOneUser as HasOneAssociation<DefaultHasOneProfile>>::association_definition(
            );
        assert_eq!(association_kind_name(definition.kind), "has_one");
    }
    // End-to-end test with these steps:
    // 1. Create the three tables in an in-memory SQLite database.
    // 2. Insert one blog, one post and one profile.
    // 3. Load them back through `has_many`, `belongs_to` and `has_one`.
    #[test]
    fn association_query_traits_load_related_records() {
        let _runtime = runtime::init_runtime();
        database::establish("sqlite::memory:").expect("sqlite in-memory connection should succeed");
        runtime::block_on(async {
            use sea_orm::ConnectionTrait;
            let db = database::db();
            db.execute_unprepared(
                "CREATE TABLE query_blogs (id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT NOT NULL)",
            )
            .await
            .expect("query_blogs table should be created");
            db.execute_unprepared(
                "CREATE TABLE query_posts (id INTEGER PRIMARY KEY AUTOINCREMENT, query_blog_id INTEGER NOT NULL, title TEXT NOT NULL)",
            )
            .await
            .expect("query_posts table should be created");
            db.execute_unprepared(
                "CREATE TABLE query_profiles (id INTEGER PRIMARY KEY AUTOINCREMENT, query_blog_id INTEGER NOT NULL, bio TEXT NOT NULL)",
            )
            .await
            .expect("query_profiles table should be created");
        });
        let blog = QueryBlog::create_sync(HashMap::from([("title".to_owned(), json!("Main"))]))
            .expect("blog should be created");
        let blog_id = blog.id.expect("blog should have an id");
        let post = QueryPost::create_sync(HashMap::from([
            ("query_blog_id".to_owned(), json!(blog_id)),
            ("title".to_owned(), json!("First")),
        ]))
        .expect("post should be created");
        QueryProfile::create_sync(HashMap::from([
            ("query_blog_id".to_owned(), json!(blog_id)),
            ("bio".to_owned(), json!("About the blog")),
        ]))
        .expect("profile should be created");
        let posts: Vec<QueryPost> = blog.has_many().expect("has_many should load related posts");
        assert_eq!(posts.len(), 1);
        assert_eq!(posts[0].title, "First");
        let owner: QueryBlog = post.belongs_to().expect("belongs_to should load the owner");
        assert_eq!(owner.title, "Main");
        let profile = blog
            .has_one()
            .expect("has_one should query the related record");
        assert_eq!(profile.expect("profile should exist").bio, "About the blog");
    }
}