gluesql_core/ast_builder/select/
limit.rs

1use {
2    super::{Prebuild, values::ValuesNode},
3    crate::{
4        ast::Query,
5        ast_builder::{
6            ExprNode, FilterNode, GroupByNode, HashJoinNode, HavingNode, JoinConstraintNode,
7            JoinNode, OrderByNode, ProjectNode, QueryNode, SelectNode, TableFactorNode,
8        },
9        result::Result,
10    },
11};
12
13#[derive(Clone, Debug)]
14pub enum PrevNode<'a> {
15    Select(SelectNode<'a>),
16    Values(ValuesNode<'a>),
17    GroupBy(GroupByNode<'a>),
18    Having(HavingNode<'a>),
19    Join(Box<JoinNode<'a>>),
20    JoinConstraint(Box<JoinConstraintNode<'a>>),
21    HashJoin(HashJoinNode<'a>),
22    Filter(FilterNode<'a>),
23    OrderBy(OrderByNode<'a>),
24    ProjectNode(Box<ProjectNode<'a>>),
25}
26
27impl<'a> Prebuild<Query> for PrevNode<'a> {
28    fn prebuild(self) -> Result<Query> {
29        match self {
30            Self::Select(node) => node.prebuild(),
31            Self::Values(node) => node.prebuild(),
32            Self::GroupBy(node) => node.prebuild(),
33            Self::Having(node) => node.prebuild(),
34            Self::Join(node) => node.prebuild(),
35            Self::JoinConstraint(node) => node.prebuild(),
36            Self::HashJoin(node) => node.prebuild(),
37            Self::Filter(node) => node.prebuild(),
38            Self::OrderBy(node) => node.prebuild(),
39            Self::ProjectNode(node) => node.prebuild(),
40        }
41    }
42}
43
44impl<'a> From<SelectNode<'a>> for PrevNode<'a> {
45    fn from(node: SelectNode<'a>) -> Self {
46        PrevNode::Select(node)
47    }
48}
49
50impl<'a> From<ValuesNode<'a>> for PrevNode<'a> {
51    fn from(node: ValuesNode<'a>) -> Self {
52        PrevNode::Values(node)
53    }
54}
55
56impl<'a> From<GroupByNode<'a>> for PrevNode<'a> {
57    fn from(node: GroupByNode<'a>) -> Self {
58        PrevNode::GroupBy(node)
59    }
60}
61
62impl<'a> From<HavingNode<'a>> for PrevNode<'a> {
63    fn from(node: HavingNode<'a>) -> Self {
64        PrevNode::Having(node)
65    }
66}
67
68impl<'a> From<JoinConstraintNode<'a>> for PrevNode<'a> {
69    fn from(node: JoinConstraintNode<'a>) -> Self {
70        PrevNode::JoinConstraint(Box::new(node))
71    }
72}
73
74impl<'a> From<JoinNode<'a>> for PrevNode<'a> {
75    fn from(node: JoinNode<'a>) -> Self {
76        PrevNode::Join(Box::new(node))
77    }
78}
79
80impl<'a> From<HashJoinNode<'a>> for PrevNode<'a> {
81    fn from(node: HashJoinNode<'a>) -> Self {
82        PrevNode::HashJoin(node)
83    }
84}
85
86impl<'a> From<FilterNode<'a>> for PrevNode<'a> {
87    fn from(node: FilterNode<'a>) -> Self {
88        PrevNode::Filter(node)
89    }
90}
91
92impl<'a> From<OrderByNode<'a>> for PrevNode<'a> {
93    fn from(node: OrderByNode<'a>) -> Self {
94        PrevNode::OrderBy(node)
95    }
96}
97
98impl<'a> From<ProjectNode<'a>> for PrevNode<'a> {
99    fn from(node: ProjectNode<'a>) -> Self {
100        PrevNode::ProjectNode(Box::new(node))
101    }
102}
103
104#[derive(Clone, Debug)]
105pub struct LimitNode<'a> {
106    prev_node: PrevNode<'a>,
107    expr: ExprNode<'a>,
108}
109
110impl<'a> LimitNode<'a> {
111    pub fn new<N: Into<PrevNode<'a>>, T: Into<ExprNode<'a>>>(prev_node: N, expr: T) -> Self {
112        Self {
113            prev_node: prev_node.into(),
114            expr: expr.into(),
115        }
116    }
117
118    pub fn alias_as(self, table_alias: &'a str) -> TableFactorNode<'a> {
119        QueryNode::LimitNode(self).alias_as(table_alias)
120    }
121}
122
123impl<'a> Prebuild<Query> for LimitNode<'a> {
124    fn prebuild(self) -> Result<Query> {
125        let mut node_data = self.prev_node.prebuild()?;
126        node_data.limit = Some(self.expr.try_into()?);
127
128        Ok(node_data)
129    }
130}
131
#[cfg(test)]
mod tests {
    use {
        crate::{
            ast::{
                Join, JoinConstraint, JoinExecutor, JoinOperator, Query, Select, SetExpr,
                Statement, TableFactor, TableWithJoins,
            },
            ast_builder::{Build, SelectItemList, col, num, table, test},
        },
        pretty_assertions::assert_eq,
    };

    /// Exercises `limit` after each kind of preceding builder node and checks
    /// the built statement against the expected SQL (or expected AST for the
    /// hash join case).
    #[test]
    fn limit() {
        // select -> limit
        test(
            table("Foo").select().limit(10).build(),
            "SELECT * FROM Foo LIMIT 10",
        );

        // group by -> limit
        test(
            table("Foo").select().group_by("bar").limit(10).build(),
            "SELECT * FROM Foo GROUP BY bar LIMIT 10",
        );

        // having (string condition) -> limit
        test(
            table("Foo")
                .select()
                .group_by("bar")
                .having("bar = 10")
                .limit(10)
                .build(),
            "SELECT * FROM Foo GROUP BY bar HAVING bar = 10 LIMIT 10",
        );

        // inner join -> limit
        test(
            table("Foo").select().join("Bar").limit(10).build(),
            "SELECT * FROM Foo JOIN Bar LIMIT 10",
        );

        // inner join with alias -> limit
        test(
            table("Foo").select().join_as("Bar", "B").limit(10).build(),
            "SELECT * FROM Foo JOIN Bar AS B LIMIT 10",
        );

        // left join -> limit
        test(
            table("Foo").select().left_join("Bar").limit(10).build(),
            "SELECT * FROM Foo LEFT JOIN Bar LIMIT 10",
        );

        // left join with alias -> limit
        test(
            table("Foo")
                .select()
                .left_join_as("Bar", "B")
                .limit(10)
                .build(),
            "SELECT * FROM Foo LEFT JOIN Bar AS B LIMIT 10",
        );

        // group by (different column) -> limit
        test(
            table("Foo").select().group_by("id").limit(10).build(),
            "SELECT * FROM Foo GROUP BY id LIMIT 10",
        );

        // having (expression node condition) -> limit
        test(
            table("Foo")
                .select()
                .group_by("id")
                .having(col("id").gt(10))
                .limit(10)
                .build(),
            "SELECT * FROM Foo GROUP BY id HAVING id > 10 LIMIT 10",
        );

        // join constraint -> limit
        test(
            table("Foo")
                .select()
                .join("Bar")
                .on("Foo.id = Bar.id")
                .limit(10)
                .build(),
            "SELECT * FROM Foo JOIN Bar ON Foo.id = Bar.id LIMIT 10",
        );

        // filter -> limit
        test(
            table("World")
                .select()
                .filter(col("id").gt(2))
                .limit(100)
                .build(),
            "SELECT * FROM World WHERE id > 2 LIMIT 100",
        );

        // order by -> limit
        test(
            table("Hello").select().order_by("score").limit(3).build(),
            "SELECT * FROM Hello ORDER BY score LIMIT 3",
        );

        // project -> limit
        test(
            table("Item").select().project("*").limit(10).build(),
            "SELECT * FROM Item LIMIT 10",
        );

        // hash join -> limit; the expectation is spelled out as an AST value
        // rather than as SQL text.
        let actual = table("Player")
            .select()
            .join("PlayerItem")
            .hash_executor("PlayerItem.user_id", "Player.id")
            .limit(100)
            .build();

        let hash_join = Join {
            relation: TableFactor::Table {
                name: "PlayerItem".to_owned(),
                alias: None,
                index: None,
            },
            join_operator: JoinOperator::Inner(JoinConstraint::None),
            join_executor: JoinExecutor::Hash {
                key_expr: col("PlayerItem.user_id").try_into().unwrap(),
                value_expr: col("Player.id").try_into().unwrap(),
                where_clause: None,
            },
        };
        let from = TableWithJoins {
            relation: TableFactor::Table {
                name: "Player".to_owned(),
                alias: None,
                index: None,
            },
            joins: vec![hash_join],
        };
        let select = Select {
            distinct: false,
            projection: SelectItemList::from("*").try_into().unwrap(),
            from,
            selection: None,
            group_by: Vec::new(),
            having: None,
        };
        let query = Query {
            body: SetExpr::Select(Box::new(select)),
            order_by: Vec::new(),
            limit: Some(num(100).try_into().unwrap()),
            offset: None,
        };
        let expected = Ok(Statement::Query(query));
        assert_eq!(actual, expected);

        // select -> limit -> aliased as a derived subquery -> select
        test(
            table("Foo")
                .select()
                .limit(10)
                .alias_as("Sub")
                .select()
                .build(),
            "SELECT * FROM (SELECT * FROM Foo LIMIT 10) Sub",
        );
    }
}