ormlite_core/
join.rs

1use crate::model::Model;
2use serde::de::Error;
3use serde::Deserialize;
4use serde::{Serialize, Serializer};
5use sql::query::Criteria;
6use sql::query::SelectColumn;
7use sql::{Expr, Operation, Where};
8use sqlx::{Database, Decode, Encode, Type};
9use std::ops::{Deref, DerefMut};
10
/// Something that can be the target of a [`Join`]: it exposes an id that the
/// join stores even when the target itself has not been loaded.
pub trait JoinMeta {
    /// Type of the id; must be cloneable, hashable and comparable.
    type IdType: Clone + Eq + PartialEq + std::hash::Hash + Send;

    /// Returns this value's id.
    fn _id(&self) -> Self::IdType;
}

16impl<T: JoinMeta> JoinMeta for Option<T> {
17    type IdType = Option<T::IdType>;
18
19    fn _id(&self) -> Self::IdType {
20        self.as_ref().map(|x| x._id())
21    }
22}
23
24impl<T: JoinMeta> JoinMeta for Join<T> {
25    type IdType = T::IdType;
26
27    fn _id(&self) -> Self::IdType {
28        self.id.clone()
29    }
30}
31
// impl<T: JoinMeta> JoinMeta for Vec<T> {
//     type IdType = T::IdType;

//     fn _id(&self) -> Self::IdType {
//         unimplemented!()
//     }
// }

40pub trait Loadable<DB, T: JoinMeta> {
41    #[allow(async_fn_in_trait)]
42    async fn load<'s, 'e, E>(&'s mut self, db: E) -> crate::error::Result<&'s T>
43    where
44        T::IdType: 'e + Send + Sync,
45        E: 'e + sqlx::Executor<'e, Database = DB>,
46        T: 's;
47}
48
49#[derive(Debug)]
50pub struct Join<T: JoinMeta> {
51    pub id: T::IdType,
52    data: JoinData<T>,
53}
54
55/// Only represents a many-to-one relationship.
56#[derive(Debug)]
57pub enum JoinData<T: JoinMeta> {
58    NotQueried,
59    QueryResult(T),
60    Modified(T),
61}
62
63impl<T: JoinMeta> Join<T> {
64    pub fn new_with_id(id: T::IdType) -> Self {
65        Self {
66            id,
67            data: JoinData::NotQueried,
68        }
69    }
70
71    pub fn new(obj: T) -> Self {
72        Self {
73            id: crate::join::JoinMeta::_id(&obj),
74            data: JoinData::Modified(obj),
75        }
76    }
77
78    /// Whether join data has been loaded into memory.
79    pub fn loaded(&self) -> bool {
80        match &self.data {
81            JoinData::NotQueried => false,
82            JoinData::QueryResult(_) => true,
83            JoinData::Modified(_) => true,
84        }
85    }
86
87    pub fn is_modified(&self) -> bool {
88        match &self.data {
89            JoinData::NotQueried => false,
90            JoinData::QueryResult(_) => false,
91            JoinData::Modified(_) => true,
92        }
93    }
94
95    /// Takes ownership and return any modified data. Leaves the Join in a NotQueried state.
96    #[doc(hidden)]
97    pub fn _take_modification(&mut self) -> Option<T> {
98        let owned = std::mem::replace(&mut self.data, JoinData::NotQueried);
99        match owned {
100            JoinData::NotQueried => None,
101            JoinData::QueryResult(_) => None,
102            JoinData::Modified(obj) => Some(obj),
103        }
104    }
105    fn transition_to_modified(&mut self) -> &mut T {
106        let owned = std::mem::replace(&mut self.data, JoinData::NotQueried);
107        match owned {
108            JoinData::NotQueried => {
109                panic!("Tried to deref_mut a joined object, but it has not been queried.")
110            }
111            JoinData::QueryResult(r) => {
112                self.data = JoinData::Modified(r);
113            }
114            JoinData::Modified(r) => {
115                self.data = JoinData::Modified(r);
116            }
117        }
118        match &mut self.data {
119            JoinData::Modified(r) => r,
120            _ => unreachable!(),
121        }
122    }
123
124    #[doc(hidden)]
125    pub fn _query_result(obj: T) -> Self {
126        Self {
127            id: obj._id(),
128            data: JoinData::QueryResult(obj),
129        }
130    }
131}
132
133impl<DB, T> Loadable<DB, T> for Join<T>
134where
135    DB: Database,
136    T: JoinMeta + Model<DB> + Send,
137    T::IdType: for<'a> Encode<'a, DB> + for<'a> Decode<'a, DB> + Type<DB>,
138{
139    async fn load<'s, 'e, E: sqlx::Executor<'e, Database = DB> + 'e>(
140        &'s mut self,
141        conn: E,
142    ) -> crate::error::Result<&'s T>
143    where
144        T::IdType: 'e + Send + Sync,
145        T: 's,
146    {
147        let model = T::fetch_one(self.id.clone(), conn).await?;
148        self.data = JoinData::QueryResult(model);
149        let s = &*self;
150        Ok(s.deref())
151    }
152}
153
154impl<T: JoinMeta> Deref for Join<T> {
155    type Target = T;
156
157    fn deref(&self) -> &Self::Target {
158        match &self.data {
159            JoinData::NotQueried => {
160                panic!("Tried to deref a joined object, but it has not been queried.")
161            }
162            JoinData::QueryResult(r) => r,
163            JoinData::Modified(r) => r,
164        }
165    }
166}
167
168impl<T: JoinMeta> DerefMut for Join<T> {
169    fn deref_mut(&mut self) -> &mut Self::Target {
170        self.transition_to_modified()
171    }
172}
173
// #[derive(Debug, Copy, Clone)]
// pub enum SemanticJoinType {
//     OneToMany,
//     ManyToOne,
//     ManyToMany(&'static str),
// }

/// Not meant for end users.
///
/// Static description of a join, recorded as `'static` string data.
#[doc(hidden)]
#[derive(Debug, Clone, Copy)]
pub enum JoinDescription {
    /// Many local rows refer to one foreign row.
    ManyToOne {
        /// the columns of the joined table
        columns: &'static [&'static str],
        /// the name of the joined table
        foreign_table: &'static str,
        /// column on the local table; presumably the one matched against
        /// `foreign_key` in the join condition (confirm with the query builder)
        local_column: &'static str,
        /// the field on the local object. joined table is aliased to this to prevent conflicts.
        field: &'static str,
        /// column on the joined table; presumably the key `local_column` refers to
        foreign_key: &'static str,
    },
}

/// Returns the select alias for `column` of the join stored in `field`, in the
/// form `__{field}__{column}`.
pub fn column_alias(field: &str, column: &str) -> String {
    format!("__{field}__{column}")
}

202pub fn select_columns(
203    columns: &'static [&'static str],
204    field: &'static str,
205) -> impl Iterator<Item = SelectColumn> + 'static {
206    columns
207        .iter()
208        .map(|&c| SelectColumn::table_column(field, c).alias(column_alias(field, c)))
209}
210
211pub fn criteria(local_table: &str, local_column: &str, remote_table: &str, remote_column: &str) -> Criteria {
212    Criteria::On(Where::Expr(Expr::BinOp(
213        Operation::Eq,
214        Expr::Column {
215            schema: None,
216            table: Some(local_table.to_string()),
217            column: local_column.to_string(),
218        }
219        .into(),
220        Expr::Column {
221            schema: None,
222            table: Some(remote_table.to_string()),
223            column: remote_column.to_string(),
224        }
225        .into(),
226    )))
227}
228
// impl JoinDescription {
// pub fn join_clause(&self, local_table: &str) -> JoinQueryFragment {
//     use SemanticJoinType::*;
//     let table = self.table_name;
//     let relation = self.relation;
//     let local_key = self.key;
//     let foreign_key = self.foreign_key;
//     let join = match &self.semantic_join_type {
//         ManyToOne => {
//             format!(r#""{relation}"."{foreign_key}" = "{local_table}"."{local_key}" "#)
//         }
//         OneToMany => {
//             format!(r#""{relation}"."{local_key}" = "{local_table}"."{foreign_key}" "#)
//         }
//         ManyToMany(_join_table) => {
//             unimplemented!()
//         }
//     };
//     JoinQueryFragment::new(table).alias(self.relation).on_raw(join)
// }

// pub fn select_columns(&self) -> impl Iterator<Item = SelectColumn> + '_ {
//     let JoinDescription::ManyToOne {
//         columns,
//         table,
//         field,
//         foreign_key,
//     } = self
//     else {
//         panic!("ManyToMany not supported yet")
//     };
//     columns
//         .iter()
//         .map(|c| SelectColumn::table_column(field, c).alias(column_alias(field, column)))
// }
// }

266impl<T: JoinMeta + Serialize> Serialize for Join<T> {
267    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268    where
269        S: Serializer,
270    {
271        match &self.data {
272            JoinData::Modified(data) => data.serialize(serializer),
273            JoinData::NotQueried => serializer.serialize_none(),
274            JoinData::QueryResult(data) => data.serialize(serializer),
275        }
276    }
277}
278
279impl<'de, T> Deserialize<'de> for Join<T>
280where
281    T: JoinMeta + Deserialize<'de>,
282{
283    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
284    where
285        D: serde::Deserializer<'de>,
286    {
287        let data = Option::<T>::deserialize(deserializer)?;
288
289        let (id_type, join_data) = match data {
290            Some(value) => (T::_id(&value), JoinData::QueryResult(value)),
291            None => return Err(D::Error::custom("Invalid value")),
292        };
293
294        Ok(Join {
295            id: id_type,
296            data: join_data,
297        })
298    }
299}