Skip to main content

dbkit_core/query.rs

1use std::marker::PhantomData;
2
3use crate::compile::{CompiledSql, SqlBuilder, ToSql};
4use crate::expr::{Expr, ExprNode, IntoExpr};
5use crate::load::{ApplyLoad, LoadChain, NoLoad};
6use crate::rel::RelationInfo;
7use crate::schema::{ColumnRef, Table};
8
/// The flavour of SQL join attached to a [`Join`].
///
/// When a query is compiled, `Inner` is rendered as ` JOIN ` and `Left` as
/// ` LEFT JOIN `.
///
/// `PartialEq`/`Eq` are derived so callers can compare kinds directly, matching the
/// other field-less enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinKind {
    /// Rendered as `JOIN`.
    Inner,
    /// Rendered as `LEFT JOIN`.
    Left,
}
14
15#[derive(Debug, Clone)]
16pub struct Join {
17    pub table: Table,
18    pub on: Expr<bool>,
19    pub kind: JoinKind,
20}
21
22#[derive(Debug, Clone)]
23pub struct SelectItem {
24    pub expr: ExprNode,
25    pub alias: Option<String>,
26}
27
/// Sort direction of a single `ORDER BY` term.
///
/// `PartialEq`/`Eq` are derived so callers can compare directions directly, matching
/// the other field-less enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderDirection {
    /// Rendered as ` ASC`.
    Asc,
    /// Rendered as ` DESC`.
    Desc,
}
33
34#[derive(Debug, Clone)]
35pub enum OrderExpr {
36    Expr(ExprNode),
37    Alias(String),
38}
39
40pub trait IntoOrderExpr {
41    fn into_order_expr(self) -> OrderExpr;
42}
43
44impl IntoOrderExpr for ColumnRef {
45    fn into_order_expr(self) -> OrderExpr {
46        OrderExpr::Expr(ExprNode::Column(self))
47    }
48}
49
50impl<M, T> IntoOrderExpr for crate::schema::Column<M, T> {
51    fn into_order_expr(self) -> OrderExpr {
52        OrderExpr::Expr(ExprNode::Column(self.as_ref()))
53    }
54}
55
56impl<T> IntoOrderExpr for Expr<T> {
57    fn into_order_expr(self) -> OrderExpr {
58        OrderExpr::Expr(self.node)
59    }
60}
61
62#[derive(Debug, Clone)]
63pub struct Order {
64    pub expr: OrderExpr,
65    pub direction: OrderDirection,
66}
67
/// Type-state marker: the query has no row-lock clause.
///
/// This is the default `Lock` parameter of `Select`; `for_update()` moves a query
/// from this state to `ForUpdateRowLock`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoRowLock;
70
/// Type-state marker: `for_update()` has been called, so the query carries a
/// `FOR UPDATE` clause and exposes `skip_locked()` / `nowait()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ForUpdateRowLock;
73
/// Type-state marker: `distinct()` has not been requested.
///
/// This is the default `DistinctState` parameter of `Select`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotDistinct;
76
/// Type-state marker: `distinct()` has been called on the query.
#[derive(Debug, Clone, Copy, Default)]
pub struct DistinctSelected;
79
/// Type-state marker: neither `group_by()` nor `having()` has been called.
///
/// This is the default `GroupState` parameter of `Select`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotGrouped;
82
/// Type-state marker: `group_by()` or `having()` has been called on the query.
#[derive(Debug, Clone, Copy, Default)]
pub struct Grouped;
85
/// Wait policy appended after `FOR UPDATE` when a locking query is compiled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RowLockWait {
    /// Plain `FOR UPDATE`; nothing is appended.
    Wait,
    /// Appends ` SKIP LOCKED`.
    SkipLocked,
    /// Appends ` NOWAIT`.
    NoWait,
}
92
93#[derive(Debug, Clone)]
94pub struct Select<Out, Loads = NoLoad, Lock = NoRowLock, DistinctState = NotDistinct, GroupState = NotGrouped> {
95    table: Table,
96    columns: Option<Vec<SelectItem>>,
97    joins: Vec<Join>,
98    filters: Vec<Expr<bool>>,
99    group_by: Vec<ExprNode>,
100    having: Vec<Expr<bool>>,
101    order_by: Vec<Order>,
102    limit: Option<u64>,
103    offset: Option<u64>,
104    distinct: bool,
105    row_lock_wait: Option<RowLockWait>,
106    loads: Loads,
107    _marker: PhantomData<Out>,
108    _lock_marker: PhantomData<Lock>,
109    _distinct_marker: PhantomData<DistinctState>,
110    _group_marker: PhantomData<GroupState>,
111}
112
113impl<Out> Select<Out, NoLoad, NoRowLock, NotDistinct, NotGrouped> {
114    pub fn new(table: Table) -> Self {
115        Self {
116            table,
117            columns: None,
118            joins: Vec::new(),
119            filters: Vec::new(),
120            group_by: Vec::new(),
121            having: Vec::new(),
122            order_by: Vec::new(),
123            limit: None,
124            offset: None,
125            distinct: false,
126            row_lock_wait: None,
127            loads: NoLoad,
128            _marker: PhantomData,
129            _lock_marker: PhantomData,
130            _distinct_marker: PhantomData,
131            _group_marker: PhantomData,
132        }
133    }
134}
135
136impl<Out, Loads, Lock, DistinctState, GroupState> Select<Out, Loads, Lock, DistinctState, GroupState> {
137    pub fn table(&self) -> Table {
138        self.table
139    }
140
141    pub fn columns_ref(&self) -> Option<&[SelectItem]> {
142        self.columns.as_deref()
143    }
144
145    pub fn joins(&self) -> &[Join] {
146        &self.joins
147    }
148
149    pub fn select_only(mut self) -> Self {
150        self.columns = Some(Vec::new());
151        self
152    }
153
154    pub fn column<T>(mut self, expr: impl IntoExpr<T>) -> Self {
155        let item = SelectItem {
156            expr: expr.into_expr().node,
157            alias: None,
158        };
159        match &mut self.columns {
160            Some(columns) => columns.push(item),
161            None => self.columns = Some(vec![item]),
162        }
163        self
164    }
165
166    pub fn column_as<T>(mut self, expr: impl IntoExpr<T>, alias: &str) -> Self {
167        let item = SelectItem {
168            expr: expr.into_expr().node,
169            alias: Some(alias.to_string()),
170        };
171        match &mut self.columns {
172            Some(columns) => columns.push(item),
173            None => self.columns = Some(vec![item]),
174        }
175        self
176    }
177
178    pub fn filter(mut self, expr: Expr<bool>) -> Self {
179        self.filters.push(expr);
180        self
181    }
182
183    pub fn join<R>(mut self, rel: R) -> Self
184    where
185        R: RelationInfo<Parent = Out>,
186    {
187        let relation = rel.relation();
188        for (table, on) in relation.join_steps() {
189            self.joins.push(Join {
190                table,
191                on,
192                kind: JoinKind::Inner,
193            });
194        }
195        self
196    }
197
198    pub fn left_join<R>(mut self, rel: R) -> Self
199    where
200        R: RelationInfo<Parent = Out>,
201    {
202        let relation = rel.relation();
203        for (table, on) in relation.join_steps() {
204            self.joins.push(Join {
205                table,
206                on,
207                kind: JoinKind::Left,
208            });
209        }
210        self
211    }
212
213    pub fn join_on(mut self, table: Table, on: Expr<bool>) -> Self {
214        self.joins.push(Join {
215            table,
216            on,
217            kind: JoinKind::Inner,
218        });
219        self
220    }
221
222    pub fn left_join_on(mut self, table: Table, on: Expr<bool>) -> Self {
223        self.joins.push(Join {
224            table,
225            on,
226            kind: JoinKind::Left,
227        });
228        self
229    }
230
231    pub fn limit(mut self, limit: u64) -> Self {
232        self.limit = Some(limit);
233        self
234    }
235
236    pub fn offset(mut self, offset: u64) -> Self {
237        self.offset = Some(offset);
238        self
239    }
240
241    pub fn order_by(mut self, order: Order) -> Self {
242        self.order_by.push(order);
243        self
244    }
245
246    pub fn columns(mut self, columns: Vec<ColumnRef>) -> Self {
247        let items = columns
248            .into_iter()
249            .map(|col| SelectItem {
250                expr: ExprNode::Column(col),
251                alias: None,
252            })
253            .collect::<Vec<_>>();
254        self.columns = Some(items);
255        self
256    }
257
258    pub fn into_model<T>(self) -> Select<T, Loads, Lock, DistinctState, GroupState> {
259        Select {
260            table: self.table,
261            columns: self.columns,
262            joins: self.joins,
263            filters: self.filters,
264            group_by: self.group_by,
265            having: self.having,
266            order_by: self.order_by,
267            limit: self.limit,
268            offset: self.offset,
269            distinct: self.distinct,
270            row_lock_wait: self.row_lock_wait,
271            loads: self.loads,
272            _marker: PhantomData,
273            _lock_marker: PhantomData,
274            _distinct_marker: PhantomData,
275            _group_marker: PhantomData,
276        }
277    }
278
279    pub fn with<L>(self, load: L) -> Select<L::Out2, LoadChain<Loads, L>, Lock, DistinctState, GroupState>
280    where
281        L: ApplyLoad<Out>,
282    {
283        Select {
284            table: self.table,
285            columns: self.columns,
286            joins: self.joins,
287            filters: self.filters,
288            group_by: self.group_by,
289            having: self.having,
290            order_by: self.order_by,
291            limit: self.limit,
292            offset: self.offset,
293            distinct: self.distinct,
294            row_lock_wait: self.row_lock_wait,
295            loads: LoadChain { prev: self.loads, load },
296            _marker: PhantomData,
297            _lock_marker: PhantomData,
298            _distinct_marker: PhantomData,
299            _group_marker: PhantomData,
300        }
301    }
302
303    pub fn compile(&self) -> CompiledSql {
304        self.compile_inner(true, true, true)
305    }
306
307    pub fn compile_without_pagination(&self) -> CompiledSql {
308        self.compile_inner(false, false, false)
309    }
310
311    pub fn compile_with_extra(&self, extra_columns: &[SelectItem], extra_joins: &[Join]) -> CompiledSql {
312        self.compile_inner_with(extra_columns, extra_joins, true, true, true)
313    }
314
315    fn compile_inner(&self, include_order: bool, include_pagination: bool, include_locking: bool) -> CompiledSql {
316        self.compile_inner_with(&[], &[], include_order, include_pagination, include_locking)
317    }
318
319    fn compile_inner_with(
320        &self,
321        extra_columns: &[SelectItem],
322        extra_joins: &[Join],
323        include_order: bool,
324        include_pagination: bool,
325        include_locking: bool,
326    ) -> CompiledSql {
327        let mut builder = SqlBuilder::new();
328        builder.push_sql("SELECT ");
329        if self.distinct {
330            builder.push_sql("DISTINCT ");
331        }
332        match &self.columns {
333            Some(columns) => {
334                for (idx, col) in columns.iter().enumerate() {
335                    if idx > 0 {
336                        builder.push_sql(", ");
337                    }
338                    col.expr.to_sql(&mut builder);
339                    if let Some(alias) = &col.alias {
340                        builder.push_sql(" AS ");
341                        builder.push_sql(alias);
342                    }
343                }
344                if !extra_columns.is_empty() {
345                    for col in extra_columns {
346                        builder.push_sql(", ");
347                        col.expr.to_sql(&mut builder);
348                        if let Some(alias) = &col.alias {
349                            builder.push_sql(" AS ");
350                            builder.push_sql(alias);
351                        }
352                    }
353                }
354            }
355            None => {
356                builder.push_sql(self.table.qualifier());
357                builder.push_sql(".*");
358                if !extra_columns.is_empty() {
359                    for col in extra_columns {
360                        builder.push_sql(", ");
361                        col.expr.to_sql(&mut builder);
362                        if let Some(alias) = &col.alias {
363                            builder.push_sql(" AS ");
364                            builder.push_sql(alias);
365                        }
366                    }
367                }
368            }
369        }
370        builder.push_sql(" FROM ");
371        builder.push_sql(&self.table.qualified_name());
372        if let Some(alias) = self.table.alias {
373            builder.push_sql(" ");
374            builder.push_sql(alias);
375        }
376        for join in &self.joins {
377            builder.push_sql(match join.kind {
378                JoinKind::Inner => " JOIN ",
379                JoinKind::Left => " LEFT JOIN ",
380            });
381            builder.push_sql(&join.table.qualified_name());
382            if let Some(alias) = join.table.alias {
383                builder.push_sql(" ");
384                builder.push_sql(alias);
385            }
386            builder.push_sql(" ON ");
387            join.on.node.to_sql(&mut builder);
388        }
389        for join in extra_joins {
390            builder.push_sql(match join.kind {
391                JoinKind::Inner => " JOIN ",
392                JoinKind::Left => " LEFT JOIN ",
393            });
394            builder.push_sql(&join.table.qualified_name());
395            if let Some(alias) = join.table.alias {
396                builder.push_sql(" ");
397                builder.push_sql(alias);
398            }
399            builder.push_sql(" ON ");
400            join.on.node.to_sql(&mut builder);
401        }
402        if !self.filters.is_empty() {
403            builder.push_sql(" WHERE ");
404            for (idx, expr) in self.filters.iter().enumerate() {
405                if idx > 0 {
406                    builder.push_sql(" AND ");
407                }
408                expr.node.to_sql(&mut builder);
409            }
410        }
411        if !self.group_by.is_empty() {
412            builder.push_sql(" GROUP BY ");
413            for (idx, expr) in self.group_by.iter().enumerate() {
414                if idx > 0 {
415                    builder.push_sql(", ");
416                }
417                expr.to_sql(&mut builder);
418            }
419        }
420        if !self.having.is_empty() {
421            builder.push_sql(" HAVING ");
422            for (idx, expr) in self.having.iter().enumerate() {
423                if idx > 0 {
424                    builder.push_sql(" AND ");
425                }
426                expr.node.to_sql(&mut builder);
427            }
428        }
429        if include_order && !self.order_by.is_empty() {
430            builder.push_sql(" ORDER BY ");
431            for (idx, order) in self.order_by.iter().enumerate() {
432                if idx > 0 {
433                    builder.push_sql(", ");
434                }
435                match &order.expr {
436                    OrderExpr::Expr(expr) => expr.to_sql(&mut builder),
437                    OrderExpr::Alias(alias) => builder.push_sql(alias),
438                }
439                builder.push_sql(match order.direction {
440                    OrderDirection::Asc => " ASC",
441                    OrderDirection::Desc => " DESC",
442                });
443            }
444        }
445        if include_pagination {
446            if let Some(limit) = self.limit {
447                builder.push_sql(" LIMIT ");
448                builder.push_sql(&limit.to_string());
449            }
450            if let Some(offset) = self.offset {
451                builder.push_sql(" OFFSET ");
452                builder.push_sql(&offset.to_string());
453            }
454        }
455        if include_locking {
456            if let Some(wait) = self.row_lock_wait {
457                builder.push_sql(" FOR UPDATE");
458                if self
459                    .joins
460                    .iter()
461                    .chain(extra_joins.iter())
462                    .any(|join| matches!(join.kind, JoinKind::Left))
463                {
464                    builder.push_sql(" OF ");
465                    builder.push_sql(self.table.qualifier());
466                }
467                match wait {
468                    RowLockWait::Wait => {}
469                    RowLockWait::SkipLocked => builder.push_sql(" SKIP LOCKED"),
470                    RowLockWait::NoWait => builder.push_sql(" NOWAIT"),
471                }
472            }
473        }
474        builder.finish()
475    }
476
477    pub fn debug_sql(&self) -> String {
478        self.compile().sql
479    }
480
481    pub fn into_parts(self) -> (CompiledSql, Loads) {
482        let compiled = self.compile();
483        (compiled, self.loads)
484    }
485
486    pub fn into_parts_with_loads(self) -> (Select<Out, NoLoad, Lock, DistinctState, GroupState>, Loads) {
487        let Select {
488            table,
489            columns,
490            joins,
491            filters,
492            group_by,
493            having,
494            order_by,
495            limit,
496            offset,
497            distinct,
498            row_lock_wait,
499            loads,
500            _marker,
501            _lock_marker,
502            _distinct_marker,
503            _group_marker,
504        } = self;
505
506        let select = Select {
507            table,
508            columns,
509            joins,
510            filters,
511            group_by,
512            having,
513            order_by,
514            limit,
515            offset,
516            distinct,
517            row_lock_wait,
518            loads: NoLoad,
519            _marker,
520            _lock_marker: PhantomData,
521            _distinct_marker: PhantomData,
522            _group_marker: PhantomData,
523        };
524
525        (select, loads)
526    }
527}
528
529impl<Out, Loads, DistinctState, GroupState> Select<Out, Loads, NoRowLock, DistinctState, GroupState> {
530    pub fn group_by<T>(mut self, expr: impl IntoExpr<T>) -> Select<Out, Loads, NoRowLock, DistinctState, Grouped> {
531        self.group_by.push(expr.into_expr().node);
532        Select {
533            table: self.table,
534            columns: self.columns,
535            joins: self.joins,
536            filters: self.filters,
537            group_by: self.group_by,
538            having: self.having,
539            order_by: self.order_by,
540            limit: self.limit,
541            offset: self.offset,
542            distinct: self.distinct,
543            row_lock_wait: self.row_lock_wait,
544            loads: self.loads,
545            _marker: PhantomData,
546            _lock_marker: PhantomData,
547            _distinct_marker: PhantomData,
548            _group_marker: PhantomData,
549        }
550    }
551
552    pub fn having(mut self, expr: Expr<bool>) -> Select<Out, Loads, NoRowLock, DistinctState, Grouped> {
553        self.having.push(expr);
554        Select {
555            table: self.table,
556            columns: self.columns,
557            joins: self.joins,
558            filters: self.filters,
559            group_by: self.group_by,
560            having: self.having,
561            order_by: self.order_by,
562            limit: self.limit,
563            offset: self.offset,
564            distinct: self.distinct,
565            row_lock_wait: self.row_lock_wait,
566            loads: self.loads,
567            _marker: PhantomData,
568            _lock_marker: PhantomData,
569            _distinct_marker: PhantomData,
570            _group_marker: PhantomData,
571        }
572    }
573}
574
575impl<Out, Loads, GroupState> Select<Out, Loads, NoRowLock, NotDistinct, GroupState> {
576    pub fn distinct(mut self) -> Select<Out, Loads, NoRowLock, DistinctSelected, GroupState> {
577        self.distinct = true;
578        Select {
579            table: self.table,
580            columns: self.columns,
581            joins: self.joins,
582            filters: self.filters,
583            group_by: self.group_by,
584            having: self.having,
585            order_by: self.order_by,
586            limit: self.limit,
587            offset: self.offset,
588            distinct: self.distinct,
589            row_lock_wait: self.row_lock_wait,
590            loads: self.loads,
591            _marker: PhantomData,
592            _lock_marker: PhantomData,
593            _distinct_marker: PhantomData,
594            _group_marker: PhantomData,
595        }
596    }
597}
598
599impl<Out, Loads> Select<Out, Loads, NoRowLock, NotDistinct, NotGrouped> {
600    pub fn for_update(self) -> Select<Out, Loads, ForUpdateRowLock, NotDistinct, NotGrouped> {
601        Select {
602            table: self.table,
603            columns: self.columns,
604            joins: self.joins,
605            filters: self.filters,
606            group_by: self.group_by,
607            having: self.having,
608            order_by: self.order_by,
609            limit: self.limit,
610            offset: self.offset,
611            distinct: self.distinct,
612            row_lock_wait: Some(self.row_lock_wait.unwrap_or(RowLockWait::Wait)),
613            loads: self.loads,
614            _marker: PhantomData,
615            _lock_marker: PhantomData,
616            _distinct_marker: PhantomData,
617            _group_marker: PhantomData,
618        }
619    }
620}
621
622impl<Out, Loads, GroupState> Select<Out, Loads, NoRowLock, DistinctSelected, GroupState> {
623    pub fn distinct(self) -> Self {
624        self
625    }
626}
627
628impl<Out, Loads> Select<Out, Loads, ForUpdateRowLock, NotDistinct, NotGrouped> {
629    pub fn for_update(self) -> Self {
630        self
631    }
632
633    pub fn skip_locked(mut self) -> Self {
634        self.row_lock_wait = Some(RowLockWait::SkipLocked);
635        self
636    }
637
638    pub fn nowait(mut self) -> Self {
639        self.row_lock_wait = Some(RowLockWait::NoWait);
640        self
641    }
642}
643
644impl Order {
645    pub fn asc(expr: impl IntoOrderExpr) -> Self {
646        Self {
647            expr: expr.into_order_expr(),
648            direction: OrderDirection::Asc,
649        }
650    }
651
652    pub fn desc(expr: impl IntoOrderExpr) -> Self {
653        Self {
654            expr: expr.into_order_expr(),
655            direction: OrderDirection::Desc,
656        }
657    }
658
659    pub fn asc_alias(alias: &str) -> Self {
660        Self {
661            expr: OrderExpr::Alias(alias.to_string()),
662            direction: OrderDirection::Asc,
663        }
664    }
665
666    pub fn desc_alias(alias: &str) -> Self {
667        Self {
668            expr: OrderExpr::Alias(alias.to_string()),
669            direction: OrderDirection::Desc,
670        }
671    }
672}