gluesql_core/ast_builder/select/project.rs

1use {
2    super::Prebuild,
3    crate::{
4        ast::Select,
5        ast_builder::{
6            ExprNode, FilterNode, GroupByNode, HashJoinNode, HavingNode, JoinConstraintNode,
7            JoinNode, LimitNode, OffsetNode, OrderByExprList, OrderByNode, QueryNode,
8            SelectItemList, SelectNode, TableFactorNode,
9        },
10        result::Result,
11    },
12};
13
14#[derive(Clone, Debug)]
15pub enum PrevNode<'a> {
16    Select(SelectNode<'a>),
17    GroupBy(GroupByNode<'a>),
18    Having(HavingNode<'a>),
19    Join(Box<JoinNode<'a>>),
20    JoinConstraint(Box<JoinConstraintNode<'a>>),
21    HashJoin(Box<HashJoinNode<'a>>),
22    Filter(FilterNode<'a>),
23}
24
25impl Prebuild<Select> for PrevNode<'_> {
26    fn prebuild(self) -> Result<Select> {
27        match self {
28            Self::Select(node) => node.prebuild(),
29            Self::GroupBy(node) => node.prebuild(),
30            Self::Having(node) => node.prebuild(),
31            Self::Join(node) => node.prebuild(),
32            Self::JoinConstraint(node) => node.prebuild(),
33            Self::HashJoin(node) => node.prebuild(),
34            Self::Filter(node) => node.prebuild(),
35        }
36    }
37}
38
39impl<'a> From<SelectNode<'a>> for PrevNode<'a> {
40    fn from(node: SelectNode<'a>) -> Self {
41        PrevNode::Select(node)
42    }
43}
44
45impl<'a> From<GroupByNode<'a>> for PrevNode<'a> {
46    fn from(node: GroupByNode<'a>) -> Self {
47        PrevNode::GroupBy(node)
48    }
49}
50
51impl<'a> From<HavingNode<'a>> for PrevNode<'a> {
52    fn from(node: HavingNode<'a>) -> Self {
53        PrevNode::Having(node)
54    }
55}
56
57impl<'a> From<JoinNode<'a>> for PrevNode<'a> {
58    fn from(node: JoinNode<'a>) -> Self {
59        PrevNode::Join(Box::new(node))
60    }
61}
62
63impl<'a> From<JoinConstraintNode<'a>> for PrevNode<'a> {
64    fn from(node: JoinConstraintNode<'a>) -> Self {
65        PrevNode::JoinConstraint(Box::new(node))
66    }
67}
68
69impl<'a> From<HashJoinNode<'a>> for PrevNode<'a> {
70    fn from(node: HashJoinNode<'a>) -> Self {
71        PrevNode::HashJoin(Box::new(node))
72    }
73}
74
75impl<'a> From<FilterNode<'a>> for PrevNode<'a> {
76    fn from(node: FilterNode<'a>) -> Self {
77        PrevNode::Filter(node)
78    }
79}
80
81#[derive(Clone, Debug)]
82pub struct ProjectNode<'a> {
83    prev_node: PrevNode<'a>,
84    select_items_list: Vec<SelectItemList<'a>>,
85}
86
87impl<'a> ProjectNode<'a> {
88    pub fn new<N: Into<PrevNode<'a>>, T: Into<SelectItemList<'a>>>(
89        prev_node: N,
90        select_items: T,
91    ) -> Self {
92        Self {
93            prev_node: prev_node.into(),
94            select_items_list: vec![select_items.into()],
95        }
96    }
97
98    #[must_use]
99    pub fn project<T: Into<SelectItemList<'a>>>(mut self, select_items: T) -> Self {
100        self.select_items_list.push(select_items.into());
101
102        self
103    }
104
105    pub fn alias_as(self, table_alias: &'a str) -> TableFactorNode<'a> {
106        QueryNode::ProjectNode(self).alias_as(table_alias)
107    }
108
109    pub fn order_by<T: Into<OrderByExprList<'a>>>(self, order_by_exprs: T) -> OrderByNode<'a> {
110        OrderByNode::new(self, order_by_exprs)
111    }
112
113    pub fn offset<T: Into<ExprNode<'a>>>(self, expr: T) -> OffsetNode<'a> {
114        OffsetNode::new(self, expr)
115    }
116
117    pub fn limit<T: Into<ExprNode<'a>>>(self, expr: T) -> LimitNode<'a> {
118        LimitNode::new(self, expr)
119    }
120}
121
122impl Prebuild<Select> for ProjectNode<'_> {
123    fn prebuild(self) -> Result<Select> {
124        let mut query: Select = self.prev_node.prebuild()?;
125        query.projection = self
126            .select_items_list
127            .into_iter()
128            .map(TryInto::try_into)
129            .collect::<Result<Vec<Vec<_>>>>()?
130            .into_iter()
131            .flatten()
132            .collect::<Vec<_>>();
133
134        Ok(query)
135    }
136}
137
#[cfg(test)]
mod tests {
    use {
        crate::{
            ast::{
                Join, JoinConstraint, JoinExecutor, JoinOperator, Query, Select, SetExpr,
                Statement, TableFactor, TableWithJoins,
            },
            ast_builder::{Build, SelectItemList, col, table, test},
        },
        pretty_assertions::assert_eq,
    };

    /// Checks that select items given in one or more `project` calls are
    /// concatenated in call order, in every supported input form.
    #[test]
    fn project() {
        // select node -> project node -> build
        let actual = table("Good").select().project("id").build();
        let expected = "SELECT id FROM Good";
        test(&actual, expected);

        // select node -> project node -> build
        let actual = table("Group").select().project("*, Group.*, name").build();
        let expected = "SELECT *, Group.*, name FROM Group";
        test(&actual, expected);

        // project node -> project node -> build
        let actual = table("Foo")
            .select()
            .project(vec!["col1", "col2"])
            .project("col3")
            .project(vec!["col4".into(), col("col5")])
            .project(col("col6"))
            .project("col7 as hello")
            .build();
        let expected = "
            SELECT
                col1, col2, col3,
                col4, col5, col6,
                col7 as hello
            FROM
                Foo
        ";
        test(&actual, expected);

        // select node -> project node -> build
        let actual = table("Aliased")
            .select()
            .project("1 + 1 as col1, col2")
            .build();
        let expected = "SELECT 1 + 1 as col1, col2 FROM Aliased";
        test(&actual, expected);
    }

    /// Exercises every `PrevNode` variant a projection can follow, plus use of
    /// a projection as an aliased subquery.
    #[test]
    fn prev_nodes() {
        // select node -> project node -> build
        let actual = table("Foo").select().project("*").build();
        let expected = "SELECT * FROM Foo";
        test(&actual, expected);

        // group by node -> project node -> build
        let actual = table("Bar")
            .select()
            .group_by("city")
            .project("city, COUNT(name) as num")
            .build();
        let expected = "
            SELECT
              city, COUNT(name) as num
            FROM Bar
            GROUP BY city
        ";
        test(&actual, expected);

        // having node -> project node -> build
        let actual = table("Cat")
            .select()
            .filter(r#"type = "cute""#)
            .group_by("age")
            .having("SUM(length) < 1000")
            .project(col("age"))
            .project("SUM(length)")
            .build();
        let expected = r#"
            SELECT age, SUM(length)
            FROM Cat
            WHERE type = "cute"
            GROUP BY age
            HAVING SUM(length) < 1000;
        "#;
        test(&actual, expected);

        // filter node -> project node -> build
        let actual = table("World")
            .select()
            .filter("population > 1000")
            .project("name")
            .build();
        let expected = "SELECT name FROM World WHERE population > 1000";
        test(&actual, expected);

        // join node -> project node -> build
        let actual = table("Foo")
            .select()
            .join("Bar")
            .project("Foo.id, Bar.name")
            .build();
        let expected = "SELECT Foo.id, Bar.name FROM Foo JOIN Bar";
        test(&actual, expected);

        // join constraint node -> project node -> build
        let actual = table("Foo")
            .select()
            .join("Bar")
            .on("Foo.id = Bar.foo_id")
            .project("Foo.id, Bar.name")
            .build();
        let expected = "SELECT Foo.id, Bar.name FROM Foo JOIN Bar ON Foo.id = Bar.foo_id";
        test(&actual, expected);

        // hash join node -> project node -> build
        // (a hash executor has no SQL spelling, so the expected AST is built by hand)
        let actual = table("Player")
            .select()
            .join("PlayerItem")
            .hash_executor("PlayerItem.user_id", "Player.id")
            .project("Player.name, PlayerItem.name")
            .build();
        let expected = {
            let join = Join {
                relation: TableFactor::Table {
                    name: "PlayerItem".to_owned(),
                    alias: None,
                    index: None,
                },
                join_operator: JoinOperator::Inner(JoinConstraint::None),
                join_executor: JoinExecutor::Hash {
                    key_expr: col("PlayerItem.user_id").try_into().unwrap(),
                    value_expr: col("Player.id").try_into().unwrap(),
                    where_clause: None,
                },
            };
            let select = Select {
                distinct: false,
                projection: SelectItemList::from("Player.name, PlayerItem.name")
                    .try_into()
                    .unwrap(),
                from: TableWithJoins {
                    relation: TableFactor::Table {
                        name: "Player".to_owned(),
                        alias: None,
                        index: None,
                    },
                    joins: vec![join],
                },
                selection: None,
                group_by: Vec::new(),
                having: None,
            };

            Ok(Statement::Query(Query {
                body: SetExpr::Select(Box::new(select)),
                order_by: Vec::new(),
                limit: None,
                offset: None,
            }))
        };
        assert_eq!(actual, expected);

        // select -> project -> derived subquery
        let actual = table("Foo")
            .select()
            .project("id")
            .alias_as("Sub")
            .select()
            .build();
        let expected = "SELECT * FROM (SELECT id FROM Foo) Sub";
        test(&actual, expected);
    }
}