Skip to main content

cast_core/
relation.rs

1//! Relationship definitions. `RelationDef` is implemented by zero-sized types
2//! generated by `#[has_many]` / `#[belongs_to]` attribute parsing.
3
4use std::marker::PhantomData;
5
6use crate::model::Model;
7use crate::pool::Pool;
8use crate::Error;
9
/// Marker trait for the type-level "kind" of a relationship.
///
/// It has no methods; implementors are zero-sized tags used as the `K`
/// parameter of `Relation` and as `RelationDef::Kind`.
pub trait RelationKind: Send + Sync + 'static {}

/// Kind tag for a one-to-many relationship; `Relation<_, C, HasMany>::load`
/// yields a `Vec<C>`.
#[derive(Debug, Clone, Copy)]
pub struct HasMany;
impl RelationKind for HasMany {}

/// Kind tag for a relationship with at most one related row;
/// `Relation<_, C, HasOne>::load` yields an `Option<C>`.
#[derive(Debug, Clone, Copy)]
pub struct HasOne;
impl RelationKind for HasOne {}

/// Kind tag for the inverse side of a relationship;
/// `Relation<_, C, BelongsTo>::load` yields an `Option<C>`.
#[derive(Debug, Clone, Copy)]
pub struct BelongsTo;
impl RelationKind for BelongsTo {}

/// Kind tag for a many-to-many relationship.
///
/// NOTE(review): no `load` implementation for this kind is visible in this
/// part of the file — confirm where (or whether) it is provided.
#[derive(Debug, Clone, Copy)]
pub struct BelongsToMany;
impl RelationKind for BelongsToMany {}
23
24pub trait RelationDef: Send + Sync {
25    type Parent: Model;
26    type Child: Model;
27    type Kind: RelationKind;
28    fn local_key() -> &'static str;
29    fn foreign_key() -> &'static str;
30}
31
32/// A live relation handle: from a parent row, load the related rows.
33pub struct Relation<P: Model, C: Model, K: RelationKind> {
34    pub parent_value: serde_json::Value,
35    pub local_key: &'static str,
36    pub foreign_key: &'static str,
37    _marker: PhantomData<(P, C, K)>,
38}
39
40impl<P: Model, C: Model> Relation<P, C, HasMany> {
41    pub async fn load(self, pool: &Pool) -> Result<Vec<C>, Error> {
42        let sql = format!(
43            "SELECT {} FROM {} WHERE {} = $1",
44            C::COLUMNS.join(", "),
45            C::TABLE,
46            self.foreign_key,
47        );
48        let rows = sqlx::query_as::<_, C>(&sql)
49            .bind(self.parent_value)
50            .fetch_all(pool)
51            .await?;
52        Ok(rows)
53    }
54}
55
56impl<P: Model, C: Model> Relation<P, C, BelongsTo> {
57    pub async fn load(self, pool: &Pool) -> Result<Option<C>, Error> {
58        let sql = format!(
59            "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
60            C::COLUMNS.join(", "),
61            C::TABLE,
62            self.local_key,
63        );
64        let row = sqlx::query_as::<_, C>(&sql)
65            .bind(self.parent_value)
66            .fetch_optional(pool)
67            .await?;
68        Ok(row)
69    }
70}
71
72impl<P: Model, C: Model> Relation<P, C, HasOne> {
73    pub async fn load(self, pool: &Pool) -> Result<Option<C>, Error> {
74        let sql = format!(
75            "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
76            C::COLUMNS.join(", "),
77            C::TABLE,
78            self.foreign_key,
79        );
80        let row = sqlx::query_as::<_, C>(&sql)
81            .bind(self.parent_value)
82            .fetch_optional(pool)
83            .await?;
84        Ok(row)
85    }
86}
87
88impl<P: Model, C: Model, K: RelationKind> Relation<P, C, K> {
89    pub fn new(
90        parent_value: serde_json::Value,
91        local_key: &'static str,
92        foreign_key: &'static str,
93    ) -> Self {
94        Self {
95            parent_value,
96            local_key,
97            foreign_key,
98            _marker: PhantomData,
99        }
100    }
101}