1use std::marker::PhantomData;
5
6use crate::model::Model;
7use crate::Error;
8
/// Marker trait implemented by every relation cardinality type.
///
/// Kind markers must be `Send + Sync + 'static`; they carry no data and are
/// only used as the `K` type parameter of `Relation`.
pub trait RelationKind: Send + Sync + 'static {}

// Declares each listed name as a public unit struct and implements
// `RelationKind` for it, forwarding any attributes (including doc comments).
macro_rules! relation_kinds {
    ($($(#[$meta:meta])* $name:ident),* $(,)?) => {
        $(
            $(#[$meta])*
            pub struct $name;

            impl RelationKind for $name {}
        )*
    };
}

relation_kinds! {
    /// One parent, many children: `Relation::<_, _, HasMany>::load` returns every
    /// child row whose `foreign_key` column equals the parent value.
    HasMany,
    /// One parent, at most one child: `load` returns the first child row whose
    /// `foreign_key` column equals the parent value, if any.
    HasOne,
    /// The related row is looked up by its `local_key` column; `load` returns at
    /// most one row.
    BelongsTo,
    /// Many-to-many marker.
    /// NOTE(review): no `load` implementation for this kind is defined alongside
    /// the other kinds — confirm whether one exists elsewhere.
    BelongsToMany,
}
22
23pub trait RelationDef: Send + Sync {
24 type Parent: Model;
25 type Child: Model;
26 type Kind: RelationKind;
27 fn local_key() -> &'static str;
28 fn foreign_key() -> &'static str;
29}
30
31pub struct Relation<P: Model, C: Model, K: RelationKind> {
33 pub parent_value: serde_json::Value,
34 pub local_key: &'static str,
35 pub foreign_key: &'static str,
36 _marker: PhantomData<(P, C, K)>,
37}
38
39impl<P: Model, C: Model> Relation<P, C, HasMany> {
40 pub async fn load(self, pool: &sqlx::PgPool) -> Result<Vec<C>, Error> {
41 let sql = format!(
42 "SELECT {} FROM {} WHERE {} = $1",
43 C::COLUMNS.join(", "),
44 C::TABLE,
45 self.foreign_key,
46 );
47 let rows = sqlx::query_as::<_, C>(&sql)
48 .bind(self.parent_value)
49 .fetch_all(pool)
50 .await?;
51 Ok(rows)
52 }
53}
54
55impl<P: Model, C: Model> Relation<P, C, BelongsTo> {
56 pub async fn load(self, pool: &sqlx::PgPool) -> Result<Option<C>, Error> {
57 let sql = format!(
58 "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
59 C::COLUMNS.join(", "),
60 C::TABLE,
61 self.local_key,
62 );
63 let row = sqlx::query_as::<_, C>(&sql)
64 .bind(self.parent_value)
65 .fetch_optional(pool)
66 .await?;
67 Ok(row)
68 }
69}
70
71impl<P: Model, C: Model> Relation<P, C, HasOne> {
72 pub async fn load(self, pool: &sqlx::PgPool) -> Result<Option<C>, Error> {
73 let sql = format!(
74 "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
75 C::COLUMNS.join(", "),
76 C::TABLE,
77 self.foreign_key,
78 );
79 let row = sqlx::query_as::<_, C>(&sql)
80 .bind(self.parent_value)
81 .fetch_optional(pool)
82 .await?;
83 Ok(row)
84 }
85}
86
87impl<P: Model, C: Model, K: RelationKind> Relation<P, C, K> {
88 pub fn new(
89 parent_value: serde_json::Value,
90 local_key: &'static str,
91 foreign_key: &'static str,
92 ) -> Self {
93 Self {
94 parent_value,
95 local_key,
96 foreign_key,
97 _marker: PhantomData,
98 }
99 }
100}