1use std::marker::PhantomData;
5
6use crate::model::Model;
7use crate::pool::Pool;
8use crate::Error;
9
/// Type-level tag describing which kind of relationship a `Relation` models.
///
/// Implementors are zero-sized marker types; the tag only selects which
/// kind-specific `load` method is available on `Relation<P, C, K>`.
pub trait RelationKind: Send + Sync + 'static {}

/// One parent row relates to many child rows.
///
/// `load` filters the child table by the relation's `foreign_key` column.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasMany;
impl RelationKind for HasMany {}

/// One parent row relates to at most one child row.
///
/// `load` filters the child table by the relation's `foreign_key` column.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasOne;
impl RelationKind for HasOne {}

/// The parent row points at a single related row.
///
/// `load` filters the related table by the relation's `local_key` column.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct BelongsTo;
impl RelationKind for BelongsTo {}

/// Many-to-many relationship tag.
///
/// NOTE(review): no `load` impl for this kind is visible in this module —
/// confirm whether it is implemented elsewhere.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct BelongsToMany;
impl RelationKind for BelongsToMany {}
23
24pub trait RelationDef: Send + Sync {
25 type Parent: Model;
26 type Child: Model;
27 type Kind: RelationKind;
28 fn local_key() -> &'static str;
29 fn foreign_key() -> &'static str;
30}
31
32pub struct Relation<P: Model, C: Model, K: RelationKind> {
34 pub parent_value: serde_json::Value,
35 pub local_key: &'static str,
36 pub foreign_key: &'static str,
37 _marker: PhantomData<(P, C, K)>,
38}
39
40impl<P: Model, C: Model> Relation<P, C, HasMany> {
41 pub async fn load(self, pool: &Pool) -> Result<Vec<C>, Error> {
42 let sql = format!(
43 "SELECT {} FROM {} WHERE {} = $1",
44 C::COLUMNS.join(", "),
45 C::TABLE,
46 self.foreign_key,
47 );
48 let rows = sqlx::query_as::<_, C>(&sql)
49 .bind(self.parent_value)
50 .fetch_all(pool)
51 .await?;
52 Ok(rows)
53 }
54}
55
56impl<P: Model, C: Model> Relation<P, C, BelongsTo> {
57 pub async fn load(self, pool: &Pool) -> Result<Option<C>, Error> {
58 let sql = format!(
59 "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
60 C::COLUMNS.join(", "),
61 C::TABLE,
62 self.local_key,
63 );
64 let row = sqlx::query_as::<_, C>(&sql)
65 .bind(self.parent_value)
66 .fetch_optional(pool)
67 .await?;
68 Ok(row)
69 }
70}
71
72impl<P: Model, C: Model> Relation<P, C, HasOne> {
73 pub async fn load(self, pool: &Pool) -> Result<Option<C>, Error> {
74 let sql = format!(
75 "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
76 C::COLUMNS.join(", "),
77 C::TABLE,
78 self.foreign_key,
79 );
80 let row = sqlx::query_as::<_, C>(&sql)
81 .bind(self.parent_value)
82 .fetch_optional(pool)
83 .await?;
84 Ok(row)
85 }
86}
87
88impl<P: Model, C: Model, K: RelationKind> Relation<P, C, K> {
89 pub fn new(
90 parent_value: serde_json::Value,
91 local_key: &'static str,
92 foreign_key: &'static str,
93 ) -> Self {
94 Self {
95 parent_value,
96 local_key,
97 foreign_key,
98 _marker: PhantomData,
99 }
100 }
101}