luna_orm_trait/
location.rs

1use super::{Location, Selection};
2use serde::{Deserialize, Serialize};
3use serde_with::serde_as;
4pub trait SelectionWithSend: Selection + Send {}
5impl<T> SelectionWithSend for T where T: Selection + Send {}
6
7pub trait LocationWithSend: Location + Send {}
8impl<T> LocationWithSend for T where T: Location + Send {}
9
10#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
11pub struct LocationExpr<T> {
12    pub val: T,
13    pub cmp: CmpOperator,
14}
15
16impl<T> LocationExpr<T> {
17    pub fn new(cmp: CmpOperator, val: T) -> Self {
18        Self { cmp, val }
19    }
20}
21
22#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
23pub struct SelectedLocationExpr<T> {
24    pub selected: bool,
25    pub val: T,
26    pub cmp: CmpOperator,
27}
28
29#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
30pub enum CmpOperator {
31    #[serde(alias = "=")]
32    Eq,
33    #[serde(alias = "<")]
34    LessThan,
35    #[serde(alias = "<=")]
36    LessOrEq,
37    #[serde(alias = ">")]
38    GreaterThan,
39    #[serde(alias = ">=")]
40    GreaterOrEq,
41    #[serde(alias = "like")]
42    Like,
43}
44
45impl CmpOperator {
46    pub fn get_sql(&self) -> &'static str {
47        match self {
48            CmpOperator::Eq => "=",
49            CmpOperator::LessThan => "<",
50            CmpOperator::LessOrEq => "<=",
51            CmpOperator::GreaterThan => ">",
52            CmpOperator::GreaterOrEq => ">=",
53            CmpOperator::Like => "LIKE",
54        }
55    }
56}
57
58/*
59#[derive(Deserialize)]
60pub struct LocationQuery<S, L>
61where
62    S: Selection + Send,
63    L: Location + Send,
64{
65    selection: S,
66    location: L,
67}
68*/
69
70#[typetag::serde(tag = "table")]
71pub trait LocatedQuery {
72    fn get_selection(&self) -> &dyn Selection;
73    fn get_location(&self) -> &dyn Location;
74}
75
76/*
77impl<S, L> LocatedQuery for LocationQuery<S, L>
78where
79    S: Selection + Send,
80    L: Location + Send,
81{
82    fn get_selection(&self) -> &dyn SelectionWithSend {
83        &self.selection
84    }
85    fn get_location(&self) -> &dyn LocationWithSend {
86        &self.location
87    }
88}
89*/
90
91#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
92pub enum JoinMode {
93    #[serde(alias = "left")]
94    Left,
95    #[serde(alias = "right")]
96    Right,
97    #[serde(alias = "outer")]
98    Outer,
99    #[serde(alias = "inner")]
100    Inner,
101}
102
103impl JoinMode {
104    pub fn get_join_operator(&self) -> &'static str {
105        match self {
106            JoinMode::Left => "LEFT JOIN",
107            JoinMode::Right => "RIGHT JOIN",
108            JoinMode::Outer => "OUTER JOIN",
109            JoinMode::Inner => "INNER JOIN",
110        }
111    }
112}
113
114#[derive(Serialize, PartialEq, Eq, Debug)]
115pub struct JoinedField {
116    table_name: String,
117    field_name: String,
118}
119
120impl<'de> serde::de::Deserialize<'de> for JoinedField {
121    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
122    where
123        D: serde::Deserializer<'de>,
124    {
125        let content = String::deserialize(deserializer)?;
126        let pair: Vec<&str> = content.split('.').collect();
127        if pair.len() != 2 {
128            return Err(serde::de::Error::custom(
129                "join field must have table name, and seperate by '.' ",
130            ));
131        }
132        Ok(Self {
133            table_name: pair.first().unwrap().to_string(),
134            field_name: pair.last().unwrap().to_string(),
135        })
136    }
137}
138
139pub type JoinedFields = (JoinedField, JoinedField);
140
141#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
142pub struct JoinedCondition {
143    mode: JoinMode,
144    #[serde(alias = "left")]
145    left_table: String,
146    #[serde(alias = "right")]
147    right_table: String,
148    #[serde(alias = "fields")]
149    joined_fields: Vec<JoinedFields>,
150}
151
152fn get_on_clause(joined_fields: &Vec<JoinedFields>) -> String {
153    let mut on_clause_vec: Vec<String> = Vec::new();
154    for field in joined_fields {
155        let on_seg = format!(
156            "{}.{} = {}.{}",
157            field.0.table_name, field.0.field_name, field.1.table_name, field.1.field_name
158        );
159        on_clause_vec.push(on_seg);
160    }
161    on_clause_vec.join(",")
162}
163
/// Implemented by join descriptions that can render themselves as SQL text.
pub trait FromClause {
    /// Returns the rendered join text as an owned string.
    fn get_from_clause(&self) -> String;
}
167
168impl FromClause for JoinedCondition {
169    fn get_from_clause(&self) -> String {
170        let on_clause = get_on_clause(&self.joined_fields);
171        let join_operator = self.mode.get_join_operator();
172        format!(
173            "{} {} {} ON {}",
174            self.left_table, join_operator, self.right_table, on_clause
175        )
176    }
177}
178
179#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
180pub struct JoinedConditionPart {
181    mode: JoinMode,
182    table: String,
183    #[serde(alias = "fields")]
184    joined_fields: Vec<JoinedFields>,
185}
186
187impl FromClause for JoinedConditionPart {
188    fn get_from_clause(&self) -> String {
189        let on_clause = get_on_clause(&self.joined_fields);
190        let join_operator = self.mode.get_join_operator();
191        format!("{} {} ON {}", join_operator, self.table, on_clause)
192    }
193}
194
195#[serde_as]
196#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
197pub struct JoinedConditionArray<const N: usize> {
198    root: JoinedCondition,
199    #[serde_as(as = "[_; N]")]
200    next: [JoinedConditionPart; N],
201}
202
203impl<const N: usize> JoinedConditionArray<N> {
204    pub fn get_from_clause(&self) -> String {
205        let root_join = self.root.get_from_clause();
206        let mut part_clauses: Vec<String> = Vec::new();
207        for part in &self.next {
208            let part_clause = part.get_from_clause();
209            part_clauses.push(part_clause);
210        }
211        let part_clause = part_clauses.join(" ");
212        format!("{} {}", root_join, part_clause)
213    }
214}
215
216#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
217#[serde(untagged)]
218pub enum JoinedConditions {
219    Two(JoinedCondition),
220    Three(JoinedConditionArray<1>),
221    Four(JoinedConditionArray<2>),
222    Five(JoinedConditionArray<3>),
223    Six(JoinedConditionArray<4>),
224    Seven(JoinedConditionArray<5>),
225    Eight(JoinedConditionArray<6>),
226    Nine(JoinedConditionArray<7>),
227    Ten(JoinedConditionArray<8>),
228    Eleven(JoinedConditionArray<9>),
229    Twelve(JoinedConditionArray<10>),
230}
231
232impl FromClause for JoinedConditions {
233    fn get_from_clause(&self) -> String {
234        match &self {
235            JoinedConditions::Two(e) => e.get_from_clause(),
236            JoinedConditions::Three(e) => e.get_from_clause(),
237            JoinedConditions::Four(e) => e.get_from_clause(),
238            JoinedConditions::Five(e) => e.get_from_clause(),
239            JoinedConditions::Six(e) => e.get_from_clause(),
240            JoinedConditions::Seven(e) => e.get_from_clause(),
241            JoinedConditions::Eight(e) => e.get_from_clause(),
242            JoinedConditions::Nine(e) => e.get_from_clause(),
243            JoinedConditions::Ten(e) => e.get_from_clause(),
244            JoinedConditions::Eleven(e) => e.get_from_clause(),
245            JoinedConditions::Twelve(e) => e.get_from_clause(),
246        }
247    }
248}
249
250#[derive(Serialize, Deserialize)]
251pub struct JoinedQuery {
252    query_vec: Vec<Box<dyn LocatedQuery>>,
253    join_conditions: JoinedConditions,
254}
255
#[cfg(test)]
mod test {
    use super::{FromClause, JoinMode, JoinedCondition, JoinedConditions};

    /// A single join deserializes from the short JSON keys into the expected fields.
    #[test]
    pub fn test_joined_condition() {
        let json = r#"{ "mode": "inner", "left": "user", "right": "class", "fields": [ ["user.id", "class.id"] ]}"#;
        let cond: JoinedCondition = serde_json::from_str(json).unwrap();

        assert_eq!(cond.mode, JoinMode::Inner);
        assert_eq!(cond.left_table, "user");
        assert_eq!(cond.right_table, "class");

        let (lhs, rhs) = &cond.joined_fields[0];
        assert_eq!(lhs.table_name, "user");
        assert_eq!(rhs.table_name, "class");
        assert_eq!(lhs.field_name, "id");
        assert_eq!(rhs.field_name, "id");
    }

    /// A root join plus two further joins lands in the `Four` variant and renders
    /// as one FROM clause.
    #[test]
    pub fn test_joined_conditions() {
        let json = r#" {"root": { "mode": "inner", "left": "user", "right": "class", "fields": [ ["user.id", "class.id"] ]}, "next":[
            {"mode": "outer", "table": "school", "fields": [["school.id", "user.id"], ["user.name", "school.name"] ] },
            {"mode": "outer", "table": "country",   "fields": [["country.id", "school.id"], ["coutry.name", "user.name"] ]}
        ] }"#;
        let conds: JoinedConditions = serde_json::from_str(json).unwrap();
        if !matches!(conds, JoinedConditions::Four(_)) {
            panic!("deserialize wrong");
        }

        let expected = "user INNER JOIN class ON user.id = class.id OUTER JOIN school ON school.id = user.id,user.name = school.name OUTER JOIN country ON country.id = school.id,coutry.name = user.name";
        assert_eq!(conds.get_from_clause(), expected);
    }
}