// luna_orm/sql_generator.rs

//use async_trait::async_trait;
use std::fmt::Write;

use luna_orm_trait::FromClause;
use luna_orm_trait::JoinedConditions;
use luna_orm_trait::{Entity, Location, Mutation, OrderBy, Pagination, Primary, Selection};

6#[derive(Default, Debug, Clone)]
7pub struct DefaultSqlGenerator {}
8impl DefaultSqlGenerator {
9    pub fn new() -> Self {
10        Self {}
11    }
12}
13impl SqlGenerator for DefaultSqlGenerator {}
14
15#[derive(Default, Debug, Clone)]
16pub struct MySqlGenerator {}
17impl MySqlGenerator {
18    pub fn new() -> Self {
19        Self {}
20    }
21}
22impl SqlGenerator for MySqlGenerator {
23    fn get_upsert_sql(&self, entity: &dyn Entity) -> String {
24        let table_name = entity.get_table_name();
25        let field_names = entity.get_insert_fields();
26        let fields = wrap_fields(&field_names, self.get_wrap_char());
27        let marks = generate_question_mark_list(&field_names);
28        let set_field_names = entity.get_upsert_set_fields();
29        let assign_clause = wrap_locate_fields(
30            &set_field_names,
31            self.get_wrap_char(),
32            self.get_place_holder(),
33        );
34
35        let upsert_sql = format!(
36            "INSERT INTO {}{}{} ({}) VALUES({})
37            ON DUPLICATE KEY UPDATE SET {}",
38            self.get_wrap_char(),
39            table_name,
40            self.get_wrap_char(),
41            fields,
42            marks,
43            assign_clause
44        )
45        .to_string();
46        self.post_process(upsert_sql)
47    }
48
49    fn get_create_sql(&self, entity: &dyn Entity) -> String {
50        let table_name = entity.get_table_name();
51        let field_names = entity.get_insert_fields();
52        let fields = wrap_fields(&field_names, self.get_wrap_char());
53        let marks = generate_question_mark_list(&field_names);
54        let insert_sql = format!(
55            "INSERT INTO {}{}{} ({}) VALUES({})",
56            self.get_wrap_char(),
57            table_name,
58            self.get_wrap_char(),
59            fields,
60            marks
61        )
62        .to_string();
63        self.post_process(insert_sql)
64    }
65}
66
67#[derive(Default, Debug, Clone)]
68pub struct PostgresGenerator {}
69impl PostgresGenerator {
70    pub fn new() -> Self {
71        Self {}
72    }
73}
74impl SqlGenerator for PostgresGenerator {
75    fn post_process(&self, origin: String) -> String {
76        self.pg_post_process(origin)
77    }
78}
79
80pub trait SqlGenerator {
81    // const WRAP_CHAR: char = '`'; can not made trait to trait object
82    #[inline(always)]
83    fn get_wrap_char(&self) -> char {
84        '`'
85    }
86
87    // const PLACE_HOLDER: char = '?'; can not made trait to trait object
88    #[inline(always)]
89    fn get_place_holder(&self) -> char {
90        '?'
91    }
92
93    #[inline]
94    fn pg_post_process(&self, origin_sql: String) -> String {
95        origin_sql
96            .chars()
97            .enumerate()
98            .map(|(i, c)| match c {
99                '?' => format!("${}", i + 1),
100                _ => c.to_string(),
101            })
102            .collect()
103    }
104
105    #[inline(always)]
106    fn post_process(&self, origin: String) -> String {
107        origin
108    }
109
110    fn get_last_row_id_sql(&self) -> &'static str {
111        "SELECT last_insert_rowid() as `last_row_id`"
112    }
113
114    fn get_select_sql(&self, selection: &dyn Selection, primay: &dyn Primary) -> String {
115        let table_name = primay.get_table_name();
116        let selected_fields: Vec<String> = selection.get_selected_fields();
117        let select_clause = wrap_fields(&selected_fields, self.get_wrap_char());
118        let located_fields = primay.get_primary_field_names();
119        let where_clause = wrap_locate_str_fields(
120            located_fields,
121            self.get_wrap_char(),
122            self.get_place_holder(),
123        );
124        let select_sql = format!(
125            "SELECT {} FROM {}{}{} WHERE {}",
126            select_clause,
127            self.get_wrap_char(),
128            table_name,
129            self.get_wrap_char(),
130            where_clause
131        )
132        .to_string();
133        self.post_process(select_sql)
134    }
135
136    fn get_search_count_sql(&self, location: &dyn Location) -> String {
137        let table_name = location.get_table_name();
138        let where_clause = location.get_where_clause(self.get_wrap_char(), self.get_place_holder());
139
140        let select_sql = format!(
141            "SELECT COUNT(1) AS {}count{} FROM {}{}{} WHERE {}",
142            self.get_wrap_char(),
143            self.get_wrap_char(),
144            self.get_wrap_char(),
145            table_name,
146            self.get_wrap_char(),
147            where_clause
148        )
149        .to_string();
150        self.post_process(select_sql)
151    }
152
153    fn get_search_all_sql(&self, selection: &dyn Selection) -> String {
154        let table_name = selection.get_table_name();
155        let selected_field_names = selection.get_selected_fields();
156        let selected_fields = wrap_fields(&selected_field_names, self.get_wrap_char());
157        let select_sql = format!(
158            "SELECT {} FROM {}{}{}",
159            selected_fields,
160            self.get_wrap_char(),
161            table_name,
162            self.get_wrap_char(),
163        )
164        .to_string();
165        self.post_process(select_sql)
166    }
167
168    fn get_search_sql(
169        &self,
170        selection: &dyn Selection,
171        location: &dyn Location,
172        order_by: Option<&dyn OrderBy>,
173    ) -> String {
174        let selected_field_names = selection.get_selected_fields();
175        let selected_fields = wrap_fields(&selected_field_names, self.get_wrap_char());
176        let table_name = location.get_table_name();
177        let where_clause = location.get_where_clause(self.get_wrap_char(), self.get_place_holder());
178        if order_by.is_none() {
179            let select_sql = format!(
180                "SELECT {} FROM {}{}{} WHERE {}",
181                selected_fields,
182                self.get_wrap_char(),
183                table_name,
184                self.get_wrap_char(),
185                where_clause
186            )
187            .to_string();
188            self.post_process(select_sql)
189        } else {
190            let order_by_field_names = order_by.unwrap().get_order_by_fields();
191            let order_by_fields = wrap_str_fields(&order_by_field_names, self.get_wrap_char());
192            let select_sql = format!(
193                "SELECT {} FROM {}{}{} WHERE {} ORDER BY {}",
194                selected_fields,
195                self.get_wrap_char(),
196                table_name,
197                self.get_wrap_char(),
198                where_clause,
199                order_by_fields
200            )
201            .to_string();
202            self.post_process(select_sql)
203        }
204    }
205
206    fn get_limit_sql(&self, page: &Pagination) -> String {
207        let offset = page.page_size * page.page_num;
208        let count = page.page_size;
209        return format!("{}, {}", offset, count);
210    }
211
212    fn get_paged_search_sql(
213        &self,
214        selection: &dyn Selection,
215        location: &dyn Location,
216        order_by: Option<&dyn OrderBy>,
217        page: &Pagination,
218    ) -> String {
219        let selected_field_names = selection.get_selected_fields();
220        let selected_fields = wrap_fields(&selected_field_names, self.get_wrap_char());
221        let table_name = location.get_table_name();
222        let where_clause = location.get_where_clause(self.get_wrap_char(), self.get_place_holder());
223        let offset = page.page_size * page.page_num;
224        let count = page.page_size;
225        if order_by.is_none() {
226            let select_sql = format!(
227                "SELECT {} FROM {}{}{} WHERE {} LIMIT {},{}",
228                selected_fields,
229                self.get_wrap_char(),
230                table_name,
231                self.get_wrap_char(),
232                where_clause,
233                offset,
234                count
235            )
236            .to_string();
237            self.post_process(select_sql)
238        } else {
239            let order_by_field_names = order_by.unwrap().get_order_by_fields();
240            let order_by_fields = wrap_str_fields(order_by_field_names, self.get_wrap_char());
241            let select_sql = format!(
242                "SELECT {} FROM {}{}{} WHERE {} ORDER BY {} LIMIT {},{}",
243                selected_fields,
244                self.get_wrap_char(),
245                table_name,
246                self.get_wrap_char(),
247                where_clause,
248                order_by_fields,
249                offset,
250                count
251            )
252            .to_string();
253            self.post_process(select_sql)
254        }
255    }
256
257    fn get_page_joined_search_sql(
258        &self,
259        joined_conds: &JoinedConditions,
260        locations: &Vec<&dyn Location>,
261        order_by: Option<&dyn OrderBy>,
262        selections: &Vec<&dyn Selection>,
263        page: &Pagination,
264    ) -> String {
265        let mut selected_field_names: Vec<String> = Vec::new();
266        for selection in selections {
267            let fields = selection.get_selected_fields();
268            selected_field_names.extend(fields);
269        }
270        let selected_fields = wrap_fields(&selected_field_names, self.get_wrap_char());
271
272        let mut location_stmts: Vec<String> = Vec::new();
273        for location in locations {
274            let where_clause =
275                location.get_where_clause(self.get_wrap_char(), self.get_place_holder());
276            location_stmts.push(where_clause);
277        }
278        let where_clause = location_stmts.join(",");
279        let from_clause = joined_conds.get_from_clause();
280        let sql: String = format!(
281            "SELECT {} FROM {} WHERE {}",
282            selected_fields, from_clause, where_clause
283        )
284        .to_string();
285        self.post_process(sql)
286    }
287
288    fn get_insert_sql(&self, entity: &dyn Entity) -> String {
289        let table_name = entity.get_table_name();
290        let field_names = entity.get_insert_fields();
291        let fields = wrap_fields(&field_names, self.get_wrap_char());
292        let marks = generate_question_mark_list(&field_names);
293        let insert_sql = format!(
294            "INSERT INTO {}{}{} ({}) VALUES({})",
295            self.get_wrap_char(),
296            table_name,
297            self.get_wrap_char(),
298            fields,
299            marks
300        )
301        .to_string();
302        self.post_process(insert_sql)
303    }
304
305    fn get_create_sql(&self, entity: &dyn Entity) -> String {
306        let table_name = entity.get_table_name();
307        let field_names = entity.get_insert_fields();
308        let fields = wrap_fields(&field_names, self.get_wrap_char());
309        let marks = generate_question_mark_list(&field_names);
310        let auto_field_name = entity.get_auto_increment_field();
311        let create_sql = if auto_field_name.is_some() {
312            let auto_field_name = auto_field_name.unwrap();
313            format!(
314                "INSERT INTO {}{}{} ({}) VALUES({}) RETURNING {}{}{} AS last_row_id",
315                self.get_wrap_char(),
316                table_name,
317                self.get_wrap_char(),
318                fields,
319                marks,
320                self.get_wrap_char(),
321                auto_field_name,
322                self.get_wrap_char()
323            )
324            .to_string()
325        } else {
326            format!(
327                "INSERT INTO {}{}{} ({}) VALUES({})",
328                self.get_wrap_char(),
329                table_name,
330                self.get_wrap_char(),
331                fields,
332                marks
333            )
334            .to_string()
335        };
336        self.post_process(create_sql)
337    }
338
339    fn get_upsert_sql(&self, entity: &dyn Entity) -> String {
340        let table_name = entity.get_table_name();
341
342        let field_names = entity.get_insert_fields();
343        let fields = wrap_fields(&field_names, self.get_wrap_char());
344        let marks = generate_question_mark_list(&field_names);
345        let set_field_names = entity.get_upsert_set_fields();
346        let assign_clause = wrap_locate_fields(
347            &set_field_names,
348            self.get_wrap_char(),
349            self.get_place_holder(),
350        );
351
352        let upsert_sql = format!(
353            "INSERT INTO {}{}{} ({}) VALUES({})
354            ON CONFLICT DO UPDATE SET {}",
355            self.get_wrap_char(),
356            table_name,
357            self.get_wrap_char(),
358            fields,
359            marks,
360            assign_clause
361        )
362        .to_string();
363        self.post_process(upsert_sql)
364    }
365
366    fn get_update_sql(&self, mutation: &dyn Mutation, primary: &dyn Primary) -> String {
367        let table_name = primary.get_table_name();
368        let body_field_names = mutation.get_fields_name();
369        let body_fields = wrap_locate_fields(
370            &body_field_names,
371            self.get_wrap_char(),
372            self.get_place_holder(),
373        );
374        let primary_field_names = primary.get_primary_field_names();
375        let primary_fields = wrap_locate_str_fields(
376            &primary_field_names,
377            self.get_wrap_char(),
378            self.get_place_holder(),
379        );
380        let update_sql = format!(
381            "UPDATE {}{}{} SET {} WHERE {}",
382            self.get_wrap_char(),
383            table_name,
384            self.get_wrap_char(),
385            body_fields,
386            primary_fields
387        )
388        .to_string();
389        self.post_process(update_sql)
390    }
391
392    fn get_change_sql(&self, mutation: &dyn Mutation, location: &dyn Location) -> String {
393        let table_name = location.get_table_name();
394        let mutation_fields = mutation.get_fields_name();
395        let update_clause = wrap_locate_fields(
396            &mutation_fields,
397            self.get_wrap_char(),
398            self.get_place_holder(),
399        );
400
401        let where_clause = location.get_where_clause(self.get_wrap_char(), self.get_place_holder());
402        let update_sql = format!(
403            "UPDATE {}{}{} SET {} WHERE {}",
404            self.get_wrap_char(),
405            table_name,
406            self.get_wrap_char(),
407            update_clause,
408            where_clause
409        )
410        .to_string();
411        self.post_process(update_sql)
412    }
413
414    fn get_delete_sql(&self, primary: &dyn Primary) -> String {
415        let table_name = primary.get_table_name();
416        let field_names = primary.get_primary_field_names();
417        let where_clause =
418            wrap_locate_str_fields(field_names, self.get_wrap_char(), self.get_place_holder());
419        let delete_sql = format!(
420            "DELETE FROM {}{}{} WHERE {}",
421            self.get_wrap_char(),
422            table_name,
423            self.get_wrap_char(),
424            where_clause
425        )
426        .to_string();
427        self.post_process(delete_sql)
428    }
429
430    fn get_purify_sql(&self, location: &dyn Location) -> String {
431        let table_name = location.get_table_name();
432        let where_clause = location.get_where_clause(self.get_wrap_char(), self.get_place_holder());
433        let delete_sql = format!(
434            "DELETE FROM {}{}{} WHERE {}",
435            self.get_wrap_char(),
436            table_name,
437            self.get_wrap_char(),
438            where_clause
439        )
440        .to_string();
441        self.post_process(delete_sql)
442    }
443}
/// Quotes every field name with `wrap_char` and joins the quoted names
/// with `,` (no space); an empty slice yields an empty string.
fn wrap_fields(fields: &[String], wrap_char: char) -> String {
    let mut quoted: Vec<String> = Vec::with_capacity(fields.len());
    for field in fields {
        quoted.push(format!("{w}{f}{w}", w = wrap_char, f = field));
    }
    quoted.join(",")
}

/// Builds a `SET`-style assignment list: each field becomes
/// `<wrap>field<wrap> = <place_holder>`, separated by `,`.
fn wrap_locate_fields(fields: &[String], wrap_char: char, place_holder: char) -> String {
    let mut clause = String::new();
    for (index, name) in fields.iter().enumerate() {
        if index > 0 {
            clause.push(',');
        }
        clause.push(wrap_char);
        clause.push_str(name);
        clause.push(wrap_char);
        clause.push_str(" = ");
        clause.push(place_holder);
    }
    clause
}

/// Same as `wrap_fields`, but for borrowed names: quotes each with
/// `wrap_char` and joins them with `,`.
fn wrap_str_fields(fields: &[&str], wrap_char: char) -> String {
    let mut joined = String::new();
    for (position, field) in fields.iter().enumerate() {
        if position != 0 {
            joined.push(',');
        }
        joined.push(wrap_char);
        joined.push_str(field);
        joined.push(wrap_char);
    }
    joined
}

/// Builds a `WHERE` predicate matching every field against a placeholder:
/// `<wrap>f1<wrap> = <ph> AND <wrap>f2<wrap> = <ph> ...`.
///
/// Used for primary-key lookups. The predicates must be combined with `AND`:
/// the previous `,` separator produced invalid SQL for composite keys.
/// Placeholders appear in the order of `fields`.
#[inline]
fn wrap_locate_str_fields(fields: &[&str], wrap_char: char, place_holder: char) -> String {
    fields
        .iter()
        .map(|e| format!("{}{}{} = {}", wrap_char, e, wrap_char, place_holder))
        .collect::<Vec<String>>()
        .join(" AND ")
}

/// PostgreSQL variant of the field/placeholder list: field i (0-based)
/// becomes `<wrap>field<wrap> = $<i + 1>`, and the entries are joined with `,`.
/// Not referenced elsewhere in this file.
fn wrap_pg_locate_str_fields(fields: &[&str], wrap_char: char) -> String {
    let mut parts: Vec<String> = Vec::with_capacity(fields.len());
    for (index, field) in fields.iter().enumerate() {
        let ordinal = index + 1;
        parts.push(format!("{w}{f}{w} = ${n}", w = wrap_char, f = field, n = ordinal));
    }
    parts.join(",")
}

489
490#[inline]
491fn generate_question_marks(fields: &[&str]) -> String {
492    fields
493        .iter()
494        .map(|_| "?".to_string())
495        .collect::<Vec<String>>()
496        .join(", ")
497}
/// One `?` per owned field name, joined with `", "` (the `VALUES(...)` list);
/// only the length of `fields` matters.
fn generate_question_mark_list(fields: &[String]) -> String {
    let marks: Vec<&str> = std::iter::repeat("?").take(fields.len()).collect();
    marks.join(", ")
}