gluesql_core/ast_builder/select/
project.rs

1use {
2    super::Prebuild,
3    crate::{
4        ast::Select,
5        ast_builder::{
6            ExprNode, FilterNode, GroupByNode, HashJoinNode, HavingNode, JoinConstraintNode,
7            JoinNode, LimitNode, OffsetNode, OrderByExprList, OrderByNode, QueryNode,
8            SelectItemList, SelectNode, TableFactorNode,
9        },
10        result::Result,
11    },
12};
13
14#[derive(Clone, Debug)]
15pub enum PrevNode<'a> {
16    Select(SelectNode<'a>),
17    GroupBy(GroupByNode<'a>),
18    Having(HavingNode<'a>),
19    Join(Box<JoinNode<'a>>),
20    JoinConstraint(Box<JoinConstraintNode<'a>>),
21    HashJoin(Box<HashJoinNode<'a>>),
22    Filter(FilterNode<'a>),
23}
24
25impl<'a> Prebuild<Select> for PrevNode<'a> {
26    fn prebuild(self) -> Result<Select> {
27        match self {
28            Self::Select(node) => node.prebuild(),
29            Self::GroupBy(node) => node.prebuild(),
30            Self::Having(node) => node.prebuild(),
31            Self::Join(node) => node.prebuild(),
32            Self::JoinConstraint(node) => node.prebuild(),
33            Self::HashJoin(node) => node.prebuild(),
34            Self::Filter(node) => node.prebuild(),
35        }
36    }
37}
38
39impl<'a> From<SelectNode<'a>> for PrevNode<'a> {
40    fn from(node: SelectNode<'a>) -> Self {
41        PrevNode::Select(node)
42    }
43}
44
45impl<'a> From<GroupByNode<'a>> for PrevNode<'a> {
46    fn from(node: GroupByNode<'a>) -> Self {
47        PrevNode::GroupBy(node)
48    }
49}
50
51impl<'a> From<HavingNode<'a>> for PrevNode<'a> {
52    fn from(node: HavingNode<'a>) -> Self {
53        PrevNode::Having(node)
54    }
55}
56
57impl<'a> From<JoinNode<'a>> for PrevNode<'a> {
58    fn from(node: JoinNode<'a>) -> Self {
59        PrevNode::Join(Box::new(node))
60    }
61}
62
63impl<'a> From<JoinConstraintNode<'a>> for PrevNode<'a> {
64    fn from(node: JoinConstraintNode<'a>) -> Self {
65        PrevNode::JoinConstraint(Box::new(node))
66    }
67}
68
69impl<'a> From<HashJoinNode<'a>> for PrevNode<'a> {
70    fn from(node: HashJoinNode<'a>) -> Self {
71        PrevNode::HashJoin(Box::new(node))
72    }
73}
74
75impl<'a> From<FilterNode<'a>> for PrevNode<'a> {
76    fn from(node: FilterNode<'a>) -> Self {
77        PrevNode::Filter(node)
78    }
79}
80
81#[derive(Clone, Debug)]
82pub struct ProjectNode<'a> {
83    prev_node: PrevNode<'a>,
84    select_items_list: Vec<SelectItemList<'a>>,
85}
86
87impl<'a> ProjectNode<'a> {
88    pub fn new<N: Into<PrevNode<'a>>, T: Into<SelectItemList<'a>>>(
89        prev_node: N,
90        select_items: T,
91    ) -> Self {
92        Self {
93            prev_node: prev_node.into(),
94            select_items_list: vec![select_items.into()],
95        }
96    }
97
98    pub fn project<T: Into<SelectItemList<'a>>>(mut self, select_items: T) -> Self {
99        self.select_items_list.push(select_items.into());
100
101        self
102    }
103
104    pub fn alias_as(self, table_alias: &'a str) -> TableFactorNode<'a> {
105        QueryNode::ProjectNode(self).alias_as(table_alias)
106    }
107
108    pub fn order_by<T: Into<OrderByExprList<'a>>>(self, order_by_exprs: T) -> OrderByNode<'a> {
109        OrderByNode::new(self, order_by_exprs)
110    }
111
112    pub fn offset<T: Into<ExprNode<'a>>>(self, expr: T) -> OffsetNode<'a> {
113        OffsetNode::new(self, expr)
114    }
115
116    pub fn limit<T: Into<ExprNode<'a>>>(self, expr: T) -> LimitNode<'a> {
117        LimitNode::new(self, expr)
118    }
119}
120
121impl<'a> Prebuild<Select> for ProjectNode<'a> {
122    fn prebuild(self) -> Result<Select> {
123        let mut query: Select = self.prev_node.prebuild()?;
124        query.projection = self
125            .select_items_list
126            .into_iter()
127            .map(TryInto::try_into)
128            .collect::<Result<Vec<Vec<_>>>>()?
129            .into_iter()
130            .flatten()
131            .collect::<Vec<_>>();
132
133        Ok(query)
134    }
135}
136
#[cfg(test)]
mod tests {
    use {
        crate::{
            ast::{
                Join, JoinConstraint, JoinExecutor, JoinOperator, Query, Select, SetExpr,
                Statement, TableFactor, TableWithJoins,
            },
            ast_builder::{Build, SelectItemList, col, table, test},
        },
        pretty_assertions::assert_eq,
    };

    /// Projections built from strings, vectors and expression nodes.
    #[test]
    fn project() {
        // single column straight after a select node
        test(
            table("Good").select().project("id").build(),
            "SELECT id FROM Good",
        );

        // wildcards mixed with a plain column
        test(
            table("Group").select().project("*, Group.*, name").build(),
            "SELECT *, Group.*, name FROM Group",
        );

        // repeated project calls accumulate their items in order
        let statement = table("Foo")
            .select()
            .project(vec!["col1", "col2"])
            .project("col3")
            .project(vec!["col4".into(), col("col5")])
            .project(col("col6"))
            .project("col7 as hello")
            .build();
        let sql = "
            SELECT
                col1, col2, col3,
                col4, col5, col6,
                col7 as hello
            FROM
                Foo
        ";
        test(statement, sql);

        // aliased expression inside a single string
        let statement = table("Aliased")
            .select()
            .project("1 + 1 as col1, col2")
            .build();
        test(statement, "SELECT 1 + 1 as col1, col2 FROM Aliased");
    }

    /// A project node attached to each kind of preceding node exercised here.
    #[test]
    fn prev_nodes() {
        // select node -> project node
        test(
            table("Foo").select().project("*").build(),
            "SELECT * FROM Foo",
        );

        // group by node -> project node
        let statement = table("Bar")
            .select()
            .group_by("city")
            .project("city, COUNT(name) as num")
            .build();
        let sql = "
            SELECT
              city, COUNT(name) as num
            FROM Bar
            GROUP BY city
        ";
        test(statement, sql);

        // having node -> project node
        let statement = table("Cat")
            .select()
            .filter(r#"type = "cute""#)
            .group_by("age")
            .having("SUM(length) < 1000")
            .project(col("age"))
            .project("SUM(length)")
            .build();
        let sql = r#"
            SELECT age, SUM(length)
            FROM Cat
            WHERE type = "cute"
            GROUP BY age
            HAVING SUM(length) < 1000;
        "#;
        test(statement, sql);

        // hash join node -> project node; compared against a hand-built AST
        let actual = table("Player")
            .select()
            .join("PlayerItem")
            .hash_executor("PlayerItem.user_id", "Player.id")
            .project("Player.name, PlayerItem.name")
            .build();

        let hash_join = Join {
            relation: TableFactor::Table {
                name: "PlayerItem".to_owned(),
                alias: None,
                index: None,
            },
            join_operator: JoinOperator::Inner(JoinConstraint::None),
            join_executor: JoinExecutor::Hash {
                key_expr: col("PlayerItem.user_id").try_into().unwrap(),
                value_expr: col("Player.id").try_into().unwrap(),
                where_clause: None,
            },
        };
        let body = Select {
            distinct: false,
            projection: SelectItemList::from("Player.name, PlayerItem.name")
                .try_into()
                .unwrap(),
            from: TableWithJoins {
                relation: TableFactor::Table {
                    name: "Player".to_owned(),
                    alias: None,
                    index: None,
                },
                joins: vec![hash_join],
            },
            selection: None,
            group_by: Vec::new(),
            having: None,
        };
        let expected = Ok(Statement::Query(Query {
            body: SetExpr::Select(Box::new(body)),
            order_by: Vec::new(),
            limit: None,
            offset: None,
        }));
        assert_eq!(actual, expected);

        // select -> project -> aliased derived subquery
        let statement = table("Foo")
            .select()
            .project("id")
            .alias_as("Sub")
            .select()
            .build();
        test(statement, "SELECT * FROM (SELECT id FROM Foo) Sub");
    }
}