1use std::ops::{Deref, DerefMut};
2use async_trait::async_trait;
3use serde::{Serialize, Serializer};
4use sqlmo::query::{Join as JoinQueryFragment};
5use sqlmo::query::SelectColumn;
6use sqlx::{Database, Decode, Encode, Type};
7use crate::model::Model;
8
/// Metadata that every model on the far side of a [`Join`] must expose.
///
/// The only requirement is a way to obtain the model's identifier, which a
/// [`Join`] stores in its `id` field alongside (or instead of) the model itself.
pub trait JoinMeta {
    /// Identifier type of the model: cloneable, `Send`, and usable as a
    /// hash-map key (`Eq` + `Hash`).
    type IdType: Clone + Send + PartialEq + Eq + std::hash::Hash;

    /// Returns (a copy of) this model's identifier.
    fn _id(&self) -> Self::IdType;
}
13
14impl<T: JoinMeta> JoinMeta for Option<T> {
15 type IdType = Option<T::IdType>;
16
17 fn _id(&self) -> Self::IdType {
18 self.as_ref().map(|x| x._id())
19 }
20}
21
22impl<T: JoinMeta> JoinMeta for Join<T> {
23 type IdType = T::IdType;
24
25 fn _id(&self) -> Self::IdType {
26 self.id.clone()
27 }
28}
29
30#[async_trait]
31pub trait Loadable<DB, T: JoinMeta> {
32 async fn load<'s, 'e, E>(&'s mut self, db: E) -> crate::error::Result<&'s T>
33 where
34 T::IdType: 'e + Send + Sync,
35 E: 'e + sqlx::Executor<'e, Database=DB>;
36}
37
38#[derive(Debug)]
39pub struct Join<T: JoinMeta> {
40 pub id: T::IdType,
41 data: JoinData<T>,
42}
43
44#[derive(Debug)]
46pub enum JoinData<T: JoinMeta> {
47 NotQueried,
48 QueryResult(T),
49 Modified(T),
50}
51
52
53impl<T: JoinMeta> Join<T> {
54 pub fn new_with_id(id: T::IdType) -> Self {
55 Self {
56 id,
57 data: JoinData::NotQueried,
58 }
59 }
60
61 pub fn new(obj: T) -> Self {
62 Self {
63 id: obj._id(),
64 data: JoinData::Modified(obj),
65 }
66 }
67
68 pub fn loaded(&self) -> bool {
70 match &self.data {
71 JoinData::NotQueried => false,
72 JoinData::QueryResult(_) => true,
73 JoinData::Modified(_) => true,
74 }
75 }
76
77 pub fn is_modified(&self) -> bool {
78 match &self.data {
79 JoinData::NotQueried => false,
80 JoinData::QueryResult(_) => false,
81 JoinData::Modified(_) => true,
82 }
83 }
84
85 #[doc(hidden)]
87 pub fn _take_modification(&mut self) -> Option<T> {
88 let owned = std::mem::replace(&mut self.data, JoinData::NotQueried);
89 match owned {
90 JoinData::NotQueried => None,
91 JoinData::QueryResult(_) => None,
92 JoinData::Modified(obj) => {
93 Some(obj)
94 }
95 }
96 }
97 fn transition_to_modified(&mut self) -> &mut T {
98 let owned = std::mem::replace(&mut self.data, JoinData::NotQueried);
99 match owned {
100 JoinData::NotQueried => panic!("Tried to deref_mut a joined object, but it has not been queried."),
101 JoinData::QueryResult(r) => {
102 self.data = JoinData::Modified(r);
103 }
104 JoinData::Modified(r) => {
105 self.data = JoinData::Modified(r);
106 }
107 }
108 match &mut self.data {
109 JoinData::Modified(r) => r,
110 _ => unreachable!(),
111 }
112 }
113
114 #[doc(hidden)]
115 pub fn _query_result(obj: T) -> Self {
116 Self {
117 id: obj._id(),
118 data: JoinData::QueryResult(obj),
119 }
120 }
121}
122
123#[async_trait]
124impl<DB, T> Loadable<DB, T> for Join<T>
125 where
126 DB: Database,
127 T: JoinMeta + Model<DB> + Send,
128 T::IdType: for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Type<DB>,
129{
130 async fn load<'s, 'e, E: sqlx::Executor<'e, Database=DB> + 'e>(&'s mut self, conn: E) -> crate::error::Result<&'s T>
131 where
132 T::IdType: 'e + Send + Sync,
133 {
134 let model = T::fetch_one(self.id.clone(), conn).await?;
135 self.data = JoinData::QueryResult(model);
136 let s = &*self;
137 Ok(s.deref())
138 }
139}
140
141impl<T: JoinMeta> Deref for Join<T> {
142 type Target = T;
143
144 fn deref(&self) -> &Self::Target {
145 match &self.data {
146 JoinData::NotQueried => panic!("Tried to deref a joined object, but it has not been queried."),
147 JoinData::QueryResult(r) => r,
148 JoinData::Modified(r) => r,
149 }
150 }
151}
152
153impl<T: JoinMeta> DerefMut for Join<T> {
154 fn deref_mut(&mut self) -> &mut Self::Target {
155 self.transition_to_modified()
156 }
157}
158
/// How the local model relates to the joined model. This decides how a
/// join description's `key` and `foreign_key` are paired in the `ON` condition.
#[derive(Clone, Copy, Debug)]
pub enum SemanticJoinType {
    /// Condition: joined `key` = local `foreign_key`.
    OneToMany,
    /// Condition: joined `foreign_key` = local `key`.
    ManyToOne,
    /// Many-to-many relation. The string is presumably the join-table name
    /// (TODO confirm); building a join clause for this variant currently
    /// hits `unimplemented!`.
    ManyToMany(&'static str),
}
165
166#[doc(hidden)]
168#[derive(Debug, Clone, Copy)]
169pub struct JoinDescription {
170 pub joined_columns: &'static [&'static str],
171 pub table_name: &'static str,
172 pub relation: &'static str,
173 pub key: &'static str,
175 pub foreign_key: &'static str,
176 pub semantic_join_type: SemanticJoinType,
177}
178
179impl JoinDescription {
180 pub fn to_join_clause(&self, local_table: &str) -> JoinQueryFragment {
181 use SemanticJoinType::*;
182 let table = self.table_name;
183 let relation = self.relation;
184 let local_key = self.key;
185 let foreign_key = self.foreign_key;
186 let join = match &self.semantic_join_type {
187 ManyToOne => {
188 format!(r#""{relation}"."{foreign_key}" = "{local_table}"."{local_key}" "#)
189 }
190 OneToMany => {
191 format!(r#""{relation}"."{local_key}" = "{local_table}"."{foreign_key}" "#)
192 }
193 ManyToMany(_join_table) => {
194 unimplemented!()
195 }
196 };
197 JoinQueryFragment::new(table)
198 .alias(self.relation)
199 .on_raw(join)
200 }
201
202 pub fn select_clause(&self) -> impl Iterator<Item=SelectColumn> + '_ {
203 self.joined_columns.iter()
204 .map(|c| SelectColumn::table_column(self.relation, c)
205 .alias(self.alias(c)))
206 }
207
208 pub fn alias(&self, column: &str) -> String {
209 format!("__{}__{}", self.relation, column)
210 }
211}
212
213impl<T: JoinMeta + Serialize> Serialize for Join<T> {
214 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
215 where
216 S: Serializer {
217 match &self.data {
218 JoinData::Modified(data) => data.serialize(serializer),
219 JoinData::NotQueried => serializer.serialize_none(),
220 JoinData::QueryResult(data) => data.serialize(serializer),
221 }
222 }
223}