nom_sql/
select.rs

1use nom::character::complete::{multispace0, multispace1};
2use std::fmt;
3use std::str;
4
5use column::Column;
6use common::FieldDefinitionExpression;
7use common::{
8    as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference,
9    unsigned_number,
10};
11use condition::{condition_expr, ConditionExpression};
12use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide};
13use nom::branch::alt;
14use nom::bytes::complete::{tag, tag_no_case};
15use nom::combinator::{map, opt};
16use nom::multi::many0;
17use nom::sequence::{delimited, preceded, terminated, tuple};
18use nom::IResult;
19use order::{order_clause, OrderClause};
20use table::Table;
21
/// A parsed `GROUP BY` clause, optionally carrying a `HAVING` condition.
///
/// Field order is significant for the derived `Hash` and serde impls; do not reorder.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct GroupByClause {
    /// Grouping columns, as produced by `field_list` after `GROUP BY`.
    pub columns: Vec<Column>,
    /// The `HAVING` condition, if one followed the column list.
    pub having: Option<ConditionExpression>,
}
27
28impl fmt::Display for GroupByClause {
29    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30        write!(f, "GROUP BY ")?;
31        write!(
32            f,
33            "{}",
34            self.columns
35                .iter()
36                .map(|c| format!("{}", c))
37                .collect::<Vec<_>>()
38                .join(", ")
39        )?;
40        if let Some(ref having) = self.having {
41            write!(f, " HAVING {}", having)?;
42        }
43        Ok(())
44    }
45}
46
/// One parsed `JOIN` clause of a `SELECT` statement.
///
/// Note: an optional leading `NATURAL` keyword is accepted by the parser but is not
/// recorded here. Field order is significant for the derived `Hash` and serde impls.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct JoinClause {
    /// The join operator, as parsed by `join_operator`.
    pub operator: JoinOperator,
    /// What is being joined: a table, a parenthesized table list, a nested join, or a subquery.
    pub right: JoinRightSide,
    /// The `USING (...)` or `ON ...` constraint.
    pub constraint: JoinConstraint,
}
53
54impl fmt::Display for JoinClause {
55    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56        write!(f, "{}", self.operator)?;
57        write!(f, " {}", self.right)?;
58        write!(f, " {}", self.constraint)?;
59        Ok(())
60    }
61}
62
/// A parsed `LIMIT n [OFFSET m]` clause.
///
/// Field order is significant for the derived `Hash` and serde impls; do not reorder.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct LimitClause {
    /// Maximum number of rows, from `LIMIT n`.
    pub limit: u64,
    /// Rows to skip, from `OFFSET m`. The parser stores 0 when `OFFSET` is absent, so an
    /// explicit `OFFSET 0` is indistinguishable from no offset.
    pub offset: u64,
}
68
69impl fmt::Display for LimitClause {
70    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71        write!(f, "LIMIT {}", self.limit)?;
72        if self.offset > 0 {
73            write!(f, " OFFSET {}", self.offset)?;
74        }
75        Ok(())
76    }
77}
78
/// A parsed `SELECT` statement.
///
/// Field order is significant for the derived `Hash` and serde impls; do not reorder.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
    /// Tables listed after `FROM`. The parser always requires `FROM`, but `Display`
    /// omits the clause when this is empty (e.g. for a `Default`-constructed value).
    pub tables: Vec<Table>,
    /// Whether `DISTINCT` followed `SELECT`.
    pub distinct: bool,
    /// The select-list expressions.
    pub fields: Vec<FieldDefinitionExpression>,
    /// Zero or more `JOIN` clauses, in source order.
    pub join: Vec<JoinClause>,
    /// The `WHERE` condition, if present.
    pub where_clause: Option<ConditionExpression>,
    /// The `GROUP BY` (and optional `HAVING`) clause, if present.
    pub group_by: Option<GroupByClause>,
    /// The `ORDER BY` clause, if present.
    pub order: Option<OrderClause>,
    /// The `LIMIT`/`OFFSET` clause, if present.
    pub limit: Option<LimitClause>,
}
90
91impl fmt::Display for SelectStatement {
92    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93        write!(f, "SELECT ")?;
94        if self.distinct {
95            write!(f, "DISTINCT ")?;
96        }
97        write!(
98            f,
99            "{}",
100            self.fields
101                .iter()
102                .map(|field| format!("{}", field))
103                .collect::<Vec<_>>()
104                .join(", ")
105        )?;
106
107        if self.tables.len() > 0 {
108            write!(f, " FROM ")?;
109            write!(
110                f,
111                "{}",
112                self.tables
113                    .iter()
114                    .map(|table| format!("{}", table))
115                    .collect::<Vec<_>>()
116                    .join(", ")
117            )?;
118        }
119        for jc in &self.join {
120            write!(f, " {}", jc)?;
121        }
122        if let Some(ref where_clause) = self.where_clause {
123            write!(f, " WHERE ")?;
124            write!(f, "{}", where_clause)?;
125        }
126        if let Some(ref group_by) = self.group_by {
127            write!(f, " {}", group_by)?;
128        }
129        if let Some(ref order) = self.order {
130            write!(f, " {}", order)?;
131        }
132        if let Some(ref limit) = self.limit {
133            write!(f, " {}", limit)?;
134        }
135        Ok(())
136    }
137}
138
139fn having_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
140    let (remaining_input, (_, _, _, ce)) = tuple((
141        multispace0,
142        tag_no_case("having"),
143        multispace1,
144        condition_expr,
145    ))(i)?;
146
147    Ok((remaining_input, ce))
148}
149
150// Parse GROUP BY clause
151pub fn group_by_clause(i: &[u8]) -> IResult<&[u8], GroupByClause> {
152    let (remaining_input, (_, _, _, columns, having)) = tuple((
153        multispace0,
154        tag_no_case("group by"),
155        multispace1,
156        field_list,
157        opt(having_clause),
158    ))(i)?;
159
160    Ok((remaining_input, GroupByClause { columns, having }))
161}
162
163fn offset(i: &[u8]) -> IResult<&[u8], u64> {
164    let (remaining_input, (_, _, _, val)) = tuple((
165        multispace0,
166        tag_no_case("offset"),
167        multispace1,
168        unsigned_number,
169    ))(i)?;
170
171    Ok((remaining_input, val))
172}
173
174// Parse LIMIT clause
175pub fn limit_clause(i: &[u8]) -> IResult<&[u8], LimitClause> {
176    let (remaining_input, (_, _, _, limit, opt_offset)) = tuple((
177        multispace0,
178        tag_no_case("limit"),
179        multispace1,
180        unsigned_number,
181        opt(offset),
182    ))(i)?;
183    let offset = match opt_offset {
184        None => 0,
185        Some(v) => v,
186    };
187
188    Ok((remaining_input, LimitClause { limit, offset }))
189}
190
191fn join_constraint(i: &[u8]) -> IResult<&[u8], JoinConstraint> {
192    let using_clause = map(
193        tuple((
194            tag_no_case("using"),
195            multispace1,
196            delimited(
197                terminated(tag("("), multispace0),
198                field_list,
199                preceded(multispace0, tag(")")),
200            ),
201        )),
202        |t| JoinConstraint::Using(t.2),
203    );
204    let on_condition = alt((
205        delimited(
206            terminated(tag("("), multispace0),
207            condition_expr,
208            preceded(multispace0, tag(")")),
209        ),
210        condition_expr,
211    ));
212    let on_clause = map(tuple((tag_no_case("on"), multispace1, on_condition)), |t| {
213        JoinConstraint::On(t.2)
214    });
215
216    alt((using_clause, on_clause))(i)
217}
218
219// Parse JOIN clause
220fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> {
221    let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple((
222        multispace0,
223        opt(terminated(tag_no_case("natural"), multispace1)),
224        join_operator,
225        multispace1,
226        join_rhs,
227        multispace1,
228        join_constraint,
229    ))(i)?;
230
231    Ok((
232        remaining_input,
233        JoinClause {
234            operator,
235            right,
236            constraint,
237        },
238    ))
239}
240
241fn join_rhs(i: &[u8]) -> IResult<&[u8], JoinRightSide> {
242    let nested_select = map(
243        tuple((
244            delimited(tag("("), nested_selection, tag(")")),
245            opt(as_alias),
246        )),
247        |t| JoinRightSide::NestedSelect(Box::new(t.0), t.1.map(String::from)),
248    );
249    let nested_join = map(delimited(tag("("), join_clause, tag(")")), |nj| {
250        JoinRightSide::NestedJoin(Box::new(nj))
251    });
252    let table = map(table_reference, |t| JoinRightSide::Table(t));
253    let tables = map(delimited(tag("("), table_list, tag(")")), |tables| {
254        JoinRightSide::Tables(tables)
255    });
256    alt((nested_select, nested_join, table, tables))(i)
257}
258
259// Parse WHERE clause of a selection
260pub fn where_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
261    let (remaining_input, (_, _, _, where_condition)) = tuple((
262        multispace0,
263        tag_no_case("where"),
264        multispace1,
265        condition_expr,
266    ))(i)?;
267
268    Ok((remaining_input, where_condition))
269}
270
271// Parse rule for a SQL selection query.
272pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
273    terminated(nested_selection, statement_terminator)(i)
274}
275
276pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
277    let (
278        remaining_input,
279        (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit),
280    ) = tuple((
281        tag_no_case("select"),
282        multispace1,
283        opt(tag_no_case("distinct")),
284        multispace0,
285        field_definition_expr,
286        delimited(multispace0, tag_no_case("from"), multispace0),
287        table_list,
288        many0(join_clause),
289        opt(where_clause),
290        opt(group_by_clause),
291        opt(order_clause),
292        opt(limit_clause),
293    ))(i)?;
294    Ok((
295        remaining_input,
296        SelectStatement {
297            tables,
298            distinct: distinct.is_some(),
299            fields,
300            join,
301            where_clause,
302            group_by,
303            order,
304            limit,
305        },
306    ))
307}
308
309#[cfg(test)]
310mod tests {
311    use super::*;
312    use case::{CaseWhenExpression, ColumnOrLiteral};
313    use column::{Column, FunctionArguments, FunctionExpression};
314    use common::{FieldDefinitionExpression, FieldValueExpression, Literal, Operator};
315    use condition::ConditionBase::*;
316    use condition::ConditionExpression::*;
317    use condition::ConditionTree;
318    use order::OrderType;
319    use table::Table;
320
321    fn columns(cols: &[&str]) -> Vec<FieldDefinitionExpression> {
322        cols.iter()
323            .map(|c| FieldDefinitionExpression::Col(Column::from(*c)))
324            .collect()
325    }
326
327    #[test]
328    fn simple_select() {
329        let qstring = "SELECT id, name FROM users;";
330
331        let res = selection(qstring.as_bytes());
332        assert_eq!(
333            res.unwrap().1,
334            SelectStatement {
335                tables: vec![Table::from("users")],
336                fields: columns(&["id", "name"]),
337                ..Default::default()
338            }
339        );
340    }
341
342    #[test]
343    fn more_involved_select() {
344        let qstring = "SELECT users.id, users.name FROM users;";
345
346        let res = selection(qstring.as_bytes());
347        assert_eq!(
348            res.unwrap().1,
349            SelectStatement {
350                tables: vec![Table::from("users")],
351                fields: columns(&["users.id", "users.name"]),
352                ..Default::default()
353            }
354        );
355    }
356
357    #[test]
358    fn select_literals() {
359        use common::Literal;
360
361        let qstring = "SELECT NULL, 1, \"foo\", CURRENT_TIME FROM users;";
362        // TODO: doesn't support selecting literals without a FROM clause, which is still valid SQL
363        //        let qstring = "SELECT NULL, 1, \"foo\";";
364
365        let res = selection(qstring.as_bytes());
366        assert_eq!(
367            res.unwrap().1,
368            SelectStatement {
369                tables: vec![Table::from("users")],
370                fields: vec![
371                    FieldDefinitionExpression::Value(FieldValueExpression::Literal(
372                        Literal::Null.into(),
373                    )),
374                    FieldDefinitionExpression::Value(FieldValueExpression::Literal(
375                        Literal::Integer(1).into(),
376                    )),
377                    FieldDefinitionExpression::Value(FieldValueExpression::Literal(
378                        Literal::String("foo".to_owned()).into(),
379                    )),
380                    FieldDefinitionExpression::Value(FieldValueExpression::Literal(
381                        Literal::CurrentTime.into(),
382                    )),
383                ],
384                ..Default::default()
385            }
386        );
387    }
388
389    #[test]
390    fn select_all() {
391        let qstring = "SELECT * FROM users;";
392
393        let res = selection(qstring.as_bytes());
394        assert_eq!(
395            res.unwrap().1,
396            SelectStatement {
397                tables: vec![Table::from("users")],
398                fields: vec![FieldDefinitionExpression::All],
399                ..Default::default()
400            }
401        );
402    }
403
404    #[test]
405    fn select_all_in_table() {
406        let qstring = "SELECT users.* FROM users, votes;";
407
408        let res = selection(qstring.as_bytes());
409        assert_eq!(
410            res.unwrap().1,
411            SelectStatement {
412                tables: vec![Table::from("users"), Table::from("votes")],
413                fields: vec![FieldDefinitionExpression::AllInTable(String::from("users"))],
414                ..Default::default()
415            }
416        );
417    }
418
419    #[test]
420    fn spaces_optional() {
421        let qstring = "SELECT id,name FROM users;";
422
423        let res = selection(qstring.as_bytes());
424        assert_eq!(
425            res.unwrap().1,
426            SelectStatement {
427                tables: vec![Table::from("users")],
428                fields: columns(&["id", "name"]),
429                ..Default::default()
430            }
431        );
432    }
433
434    #[test]
435    fn case_sensitivity() {
436        let qstring_lc = "select id, name from users;";
437        let qstring_uc = "SELECT id, name FROM users;";
438
439        assert_eq!(
440            selection(qstring_lc.as_bytes()).unwrap(),
441            selection(qstring_uc.as_bytes()).unwrap()
442        );
443    }
444
445    #[test]
446    fn termination() {
447        let qstring_sem = "select id, name from users;";
448        let qstring_nosem = "select id, name from users";
449        let qstring_linebreak = "select id, name from users\n";
450
451        let r1 = selection(qstring_sem.as_bytes()).unwrap();
452        let r2 = selection(qstring_nosem.as_bytes()).unwrap();
453        let r3 = selection(qstring_linebreak.as_bytes()).unwrap();
454        assert_eq!(r1, r2);
455        assert_eq!(r2, r3);
456    }
457
458    #[test]
459    fn where_clause() {
460        let qstring = "select * from ContactInfo where email=?;";
461
462        let res = selection(qstring.as_bytes());
463
464        let expected_left = Base(Field(Column::from("email")));
465        let expected_where_cond = Some(ComparisonOp(ConditionTree {
466            left: Box::new(expected_left),
467            right: Box::new(Base(Literal(Literal::Placeholder))),
468            operator: Operator::Equal,
469        }));
470        assert_eq!(
471            res.unwrap().1,
472            SelectStatement {
473                tables: vec![Table::from("ContactInfo")],
474                fields: vec![FieldDefinitionExpression::All],
475                where_clause: expected_where_cond,
476                ..Default::default()
477            }
478        );
479    }
480
481    #[test]
482    fn limit_clause() {
483        let qstring1 = "select * from users limit 10\n";
484        let qstring2 = "select * from users limit 10 offset 10\n";
485
486        let expected_lim1 = LimitClause {
487            limit: 10,
488            offset: 0,
489        };
490        let expected_lim2 = LimitClause {
491            limit: 10,
492            offset: 10,
493        };
494
495        let res1 = selection(qstring1.as_bytes());
496        let res2 = selection(qstring2.as_bytes());
497        assert_eq!(res1.unwrap().1.limit, Some(expected_lim1));
498        assert_eq!(res2.unwrap().1.limit, Some(expected_lim2));
499    }
500
501    #[test]
502    fn table_alias() {
503        let qstring1 = "select * from PaperTag as t;";
504        // let qstring2 = "select * from PaperTag t;";
505
506        let res1 = selection(qstring1.as_bytes());
507        assert_eq!(
508            res1.clone().unwrap().1,
509            SelectStatement {
510                tables: vec![Table {
511                    name: String::from("PaperTag"),
512                    alias: Some(String::from("t")),
513                },],
514                fields: vec![FieldDefinitionExpression::All],
515                ..Default::default()
516            }
517        );
518        // let res2 = selection(qstring2.as_bytes());
519        // assert_eq!(res1.unwrap().1, res2.unwrap().1);
520    }
521
522    #[test]
523    fn column_alias() {
524        let qstring1 = "select name as TagName from PaperTag;";
525        let qstring2 = "select PaperTag.name as TagName from PaperTag;";
526
527        let res1 = selection(qstring1.as_bytes());
528        assert_eq!(
529            res1.clone().unwrap().1,
530            SelectStatement {
531                tables: vec![Table::from("PaperTag")],
532                fields: vec![FieldDefinitionExpression::Col(Column {
533                    name: String::from("name"),
534                    alias: Some(String::from("TagName")),
535                    table: None,
536                    function: None,
537                }),],
538                ..Default::default()
539            }
540        );
541        let res2 = selection(qstring2.as_bytes());
542        assert_eq!(
543            res2.clone().unwrap().1,
544            SelectStatement {
545                tables: vec![Table::from("PaperTag")],
546                fields: vec![FieldDefinitionExpression::Col(Column {
547                    name: String::from("name"),
548                    alias: Some(String::from("TagName")),
549                    table: Some(String::from("PaperTag")),
550                    function: None,
551                }),],
552                ..Default::default()
553            }
554        );
555    }
556
557    #[test]
558    fn column_alias_no_as() {
559        let qstring1 = "select name TagName from PaperTag;";
560        let qstring2 = "select PaperTag.name TagName from PaperTag;";
561
562        let res1 = selection(qstring1.as_bytes());
563        assert_eq!(
564            res1.clone().unwrap().1,
565            SelectStatement {
566                tables: vec![Table::from("PaperTag")],
567                fields: vec![FieldDefinitionExpression::Col(Column {
568                    name: String::from("name"),
569                    alias: Some(String::from("TagName")),
570                    table: None,
571                    function: None,
572                }),],
573                ..Default::default()
574            }
575        );
576        let res2 = selection(qstring2.as_bytes());
577        assert_eq!(
578            res2.clone().unwrap().1,
579            SelectStatement {
580                tables: vec![Table::from("PaperTag")],
581                fields: vec![FieldDefinitionExpression::Col(Column {
582                    name: String::from("name"),
583                    alias: Some(String::from("TagName")),
584                    table: Some(String::from("PaperTag")),
585                    function: None,
586                }),],
587                ..Default::default()
588            }
589        );
590    }
591
592    #[test]
593    fn distinct() {
594        let qstring = "select distinct tag from PaperTag where paperId=?;";
595
596        let res = selection(qstring.as_bytes());
597        let expected_left = Base(Field(Column::from("paperId")));
598        let expected_where_cond = Some(ComparisonOp(ConditionTree {
599            left: Box::new(expected_left),
600            right: Box::new(Base(Literal(Literal::Placeholder))),
601            operator: Operator::Equal,
602        }));
603        assert_eq!(
604            res.unwrap().1,
605            SelectStatement {
606                tables: vec![Table::from("PaperTag")],
607                distinct: true,
608                fields: columns(&["tag"]),
609                where_clause: expected_where_cond,
610                ..Default::default()
611            }
612        );
613    }
614
615    #[test]
616    fn simple_condition_expr() {
617        let qstring = "select infoJson from PaperStorage where paperId=? and paperStorageId=?;";
618
619        let res = selection(qstring.as_bytes());
620
621        let left_ct = ConditionTree {
622            left: Box::new(Base(Field(Column::from("paperId")))),
623            right: Box::new(Base(Literal(Literal::Placeholder))),
624            operator: Operator::Equal,
625        };
626        let left_comp = Box::new(ComparisonOp(left_ct));
627        let right_ct = ConditionTree {
628            left: Box::new(Base(Field(Column::from("paperStorageId")))),
629            right: Box::new(Base(Literal(Literal::Placeholder))),
630            operator: Operator::Equal,
631        };
632        let right_comp = Box::new(ComparisonOp(right_ct));
633        let expected_where_cond = Some(LogicalOp(ConditionTree {
634            left: left_comp,
635            right: right_comp,
636            operator: Operator::And,
637        }));
638        assert_eq!(
639            res.unwrap().1,
640            SelectStatement {
641                tables: vec![Table::from("PaperStorage")],
642                fields: columns(&["infoJson"]),
643                where_clause: expected_where_cond,
644                ..Default::default()
645            }
646        );
647    }
648
649    #[test]
650    fn where_and_limit_clauses() {
651        let qstring = "select * from users where id = ? limit 10\n";
652        let res = selection(qstring.as_bytes());
653
654        let expected_lim = Some(LimitClause {
655            limit: 10,
656            offset: 0,
657        });
658        let ct = ConditionTree {
659            left: Box::new(Base(Field(Column::from("id")))),
660            right: Box::new(Base(Literal(Literal::Placeholder))),
661            operator: Operator::Equal,
662        };
663        let expected_where_cond = Some(ComparisonOp(ct));
664
665        assert_eq!(
666            res.unwrap().1,
667            SelectStatement {
668                tables: vec![Table::from("users")],
669                fields: vec![FieldDefinitionExpression::All],
670                where_clause: expected_where_cond,
671                limit: expected_lim,
672                ..Default::default()
673            }
674        );
675    }
676
677    #[test]
678    fn aggregation_column() {
679        let qstring = "SELECT max(addr_id) FROM address;";
680
681        let res = selection(qstring.as_bytes());
682        let agg_expr = FunctionExpression::Max(FunctionArguments::Column(Column::from("addr_id")));
683        assert_eq!(
684            res.unwrap().1,
685            SelectStatement {
686                tables: vec![Table::from("address")],
687                fields: vec![FieldDefinitionExpression::Col(Column {
688                    name: String::from("max(addr_id)"),
689                    alias: None,
690                    table: None,
691                    function: Some(Box::new(agg_expr)),
692                }),],
693                ..Default::default()
694            }
695        );
696    }
697
698    #[test]
699    fn aggregation_column_with_alias() {
700        let qstring = "SELECT max(addr_id) AS max_addr FROM address;";
701
702        let res = selection(qstring.as_bytes());
703        let agg_expr = FunctionExpression::Max(FunctionArguments::Column(Column::from("addr_id")));
704        let expected_stmt = SelectStatement {
705            tables: vec![Table::from("address")],
706            fields: vec![FieldDefinitionExpression::Col(Column {
707                name: String::from("max_addr"),
708                alias: Some(String::from("max_addr")),
709                table: None,
710                function: Some(Box::new(agg_expr)),
711            })],
712            ..Default::default()
713        };
714        assert_eq!(res.unwrap().1, expected_stmt);
715    }
716
717    #[test]
718    fn count_all() {
719        let qstring = "SELECT COUNT(*) FROM votes GROUP BY aid;";
720
721        let res = selection(qstring.as_bytes());
722        let agg_expr = FunctionExpression::CountStar;
723        let expected_stmt = SelectStatement {
724            tables: vec![Table::from("votes")],
725            fields: vec![FieldDefinitionExpression::Col(Column {
726                name: String::from("count(*)"),
727                alias: None,
728                table: None,
729                function: Some(Box::new(agg_expr)),
730            })],
731            group_by: Some(GroupByClause {
732                columns: vec![Column::from("aid")],
733                having: None,
734            }),
735            ..Default::default()
736        };
737        assert_eq!(res.unwrap().1, expected_stmt);
738    }
739
740    #[test]
741    fn count_distinct() {
742        let qstring = "SELECT COUNT(DISTINCT vote_id) FROM votes GROUP BY aid;";
743
744        let res = selection(qstring.as_bytes());
745        let agg_expr =
746            FunctionExpression::Count(FunctionArguments::Column(Column::from("vote_id")), true);
747        let expected_stmt = SelectStatement {
748            tables: vec![Table::from("votes")],
749            fields: vec![FieldDefinitionExpression::Col(Column {
750                name: String::from("count(distinct vote_id)"),
751                alias: None,
752                table: None,
753                function: Some(Box::new(agg_expr)),
754            })],
755            group_by: Some(GroupByClause {
756                columns: vec![Column::from("aid")],
757                having: None,
758            }),
759            ..Default::default()
760        };
761        assert_eq!(res.unwrap().1, expected_stmt);
762    }
763
764    #[test]
765    fn count_filter() {
766        let qstring =
767            "SELECT COUNT(CASE WHEN vote_id > 10 THEN vote_id END) FROM votes GROUP BY aid;";
768        let res = selection(qstring.as_bytes());
769
770        let filter_cond = ComparisonOp(ConditionTree {
771            left: Box::new(Base(Field(Column::from("vote_id")))),
772            right: Box::new(Base(Literal(Literal::Integer(10.into())))),
773            operator: Operator::Greater,
774        });
775        let agg_expr = FunctionExpression::Count(
776            FunctionArguments::Conditional(CaseWhenExpression {
777                then_expr: ColumnOrLiteral::Column(Column::from("vote_id")),
778                else_expr: None,
779                condition: filter_cond,
780            }),
781            false,
782        );
783        let expected_stmt = SelectStatement {
784            tables: vec![Table::from("votes")],
785            fields: vec![FieldDefinitionExpression::Col(Column {
786                name: format!("{}", agg_expr),
787                alias: None,
788                table: None,
789                function: Some(Box::new(agg_expr)),
790            })],
791            group_by: Some(GroupByClause {
792                columns: vec![Column::from("aid")],
793                having: None,
794            }),
795            ..Default::default()
796        };
797        assert_eq!(res.unwrap().1, expected_stmt);
798    }
799
800    #[test]
801    fn sum_filter() {
802        let qstring = "SELECT SUM(CASE WHEN sign = 1 THEN vote_id END) FROM votes GROUP BY aid;";
803
804        let res = selection(qstring.as_bytes());
805
806        let filter_cond = ComparisonOp(ConditionTree {
807            left: Box::new(Base(Field(Column::from("sign")))),
808            right: Box::new(Base(Literal(Literal::Integer(1.into())))),
809            operator: Operator::Equal,
810        });
811        let agg_expr = FunctionExpression::Sum(
812            FunctionArguments::Conditional(CaseWhenExpression {
813                then_expr: ColumnOrLiteral::Column(Column::from("vote_id")),
814                else_expr: None,
815                condition: filter_cond,
816            }),
817            false,
818        );
819        let expected_stmt = SelectStatement {
820            tables: vec![Table::from("votes")],
821            fields: vec![FieldDefinitionExpression::Col(Column {
822                name: format!("{}", agg_expr),
823                alias: None,
824                table: None,
825                function: Some(Box::new(agg_expr)),
826            })],
827            group_by: Some(GroupByClause {
828                columns: vec![Column::from("aid")],
829                having: None,
830            }),
831            ..Default::default()
832        };
833        assert_eq!(res.unwrap().1, expected_stmt);
834    }
835
836    #[test]
837    fn sum_filter_else_literal() {
838        let qstring =
839            "SELECT SUM(CASE WHEN sign = 1 THEN vote_id ELSE 6 END) FROM votes GROUP BY aid;";
840
841        let res = selection(qstring.as_bytes());
842
843        let filter_cond = ComparisonOp(ConditionTree {
844            left: Box::new(Base(Field(Column::from("sign")))),
845            right: Box::new(Base(Literal(Literal::Integer(1.into())))),
846            operator: Operator::Equal,
847        });
848        let agg_expr = FunctionExpression::Sum(
849            FunctionArguments::Conditional(CaseWhenExpression {
850                then_expr: ColumnOrLiteral::Column(Column::from("vote_id")),
851                else_expr: Some(ColumnOrLiteral::Literal(Literal::Integer(6))),
852                condition: filter_cond,
853            }),
854            false,
855        );
856        let expected_stmt = SelectStatement {
857            tables: vec![Table::from("votes")],
858            fields: vec![FieldDefinitionExpression::Col(Column {
859                name: format!("{}", agg_expr),
860                alias: None,
861                table: None,
862                function: Some(Box::new(agg_expr)),
863            })],
864            group_by: Some(GroupByClause {
865                columns: vec![Column::from("aid")],
866                having: None,
867            }),
868            ..Default::default()
869        };
870        assert_eq!(res.unwrap().1, expected_stmt);
871    }
872
873    #[test]
874    fn count_filter_lobsters() {
875        let qstring = "SELECT
876            COUNT(CASE WHEN votes.story_id IS NULL AND votes.vote = 0 THEN votes.vote END) as votes
877            FROM votes
878            GROUP BY votes.comment_id;";
879
880        let res = selection(qstring.as_bytes());
881
882        let filter_cond = LogicalOp(ConditionTree {
883            left: Box::new(ComparisonOp(ConditionTree {
884                left: Box::new(Base(Field(Column::from("votes.story_id")))),
885                right: Box::new(Base(Literal(Literal::Null))),
886                operator: Operator::Equal,
887            })),
888            right: Box::new(ComparisonOp(ConditionTree {
889                left: Box::new(Base(Field(Column::from("votes.vote")))),
890                right: Box::new(Base(Literal(Literal::Integer(0)))),
891                operator: Operator::Equal,
892            })),
893            operator: Operator::And,
894        });
895        let agg_expr = FunctionExpression::Count(
896            FunctionArguments::Conditional(CaseWhenExpression {
897                then_expr: ColumnOrLiteral::Column(Column::from("votes.vote")),
898                else_expr: None,
899                condition: filter_cond,
900            }),
901            false,
902        );
903        let expected_stmt = SelectStatement {
904            tables: vec![Table::from("votes")],
905            fields: vec![FieldDefinitionExpression::Col(Column {
906                name: String::from("votes"),
907                alias: Some(String::from("votes")),
908                table: None,
909                function: Some(Box::new(agg_expr)),
910            })],
911            group_by: Some(GroupByClause {
912                columns: vec![Column::from("votes.comment_id")],
913                having: None,
914            }),
915            ..Default::default()
916        };
917        assert_eq!(res.unwrap().1, expected_stmt);
918    }
919
920    #[test]
921    fn moderately_complex_selection() {
922        let qstring = "SELECT * FROM item, author WHERE item.i_a_id = author.a_id AND \
923                       item.i_subject = ? ORDER BY item.i_title limit 50;";
924
925        let res = selection(qstring.as_bytes());
926        let expected_where_cond = Some(LogicalOp(ConditionTree {
927            left: Box::new(ComparisonOp(ConditionTree {
928                left: Box::new(Base(Field(Column::from("item.i_a_id")))),
929                right: Box::new(Base(Field(Column::from("author.a_id")))),
930                operator: Operator::Equal,
931            })),
932            right: Box::new(ComparisonOp(ConditionTree {
933                left: Box::new(Base(Field(Column::from("item.i_subject")))),
934                right: Box::new(Base(Literal(Literal::Placeholder))),
935                operator: Operator::Equal,
936            })),
937            operator: Operator::And,
938        }));
939        assert_eq!(
940            res.unwrap().1,
941            SelectStatement {
942                tables: vec![Table::from("item"), Table::from("author")],
943                fields: vec![FieldDefinitionExpression::All],
944                where_clause: expected_where_cond,
945                order: Some(OrderClause {
946                    columns: vec![("item.i_title".into(), OrderType::OrderAscending)],
947                }),
948                limit: Some(LimitClause {
949                    limit: 50,
950                    offset: 0,
951                }),
952                ..Default::default()
953            }
954        );
955    }
956
957    #[test]
958    fn simple_joins() {
959        let qstring = "select paperId from PaperConflict join PCMember using (contactId);";
960
961        let res = selection(qstring.as_bytes());
962        let expected_stmt = SelectStatement {
963            tables: vec![Table::from("PaperConflict")],
964            fields: columns(&["paperId"]),
965            join: vec![JoinClause {
966                operator: JoinOperator::Join,
967                right: JoinRightSide::Table(Table::from("PCMember")),
968                constraint: JoinConstraint::Using(vec![Column::from("contactId")]),
969            }],
970            ..Default::default()
971        };
972        assert_eq!(res.unwrap().1, expected_stmt);
973
974        // slightly simplified from
975        // "select PCMember.contactId, group_concat(reviewType separator '')
976        // from PCMember left join PaperReview on (PCMember.contactId=PaperReview.contactId)
977        // group by PCMember.contactId"
978        let qstring = "select PCMember.contactId \
979                       from PCMember \
980                       join PaperReview on (PCMember.contactId=PaperReview.contactId) \
981                       order by contactId;";
982
983        let res = selection(qstring.as_bytes());
984        let ct = ConditionTree {
985            left: Box::new(Base(Field(Column::from("PCMember.contactId")))),
986            right: Box::new(Base(Field(Column::from("PaperReview.contactId")))),
987            operator: Operator::Equal,
988        };
989        let join_cond = ConditionExpression::ComparisonOp(ct);
990        let expected = SelectStatement {
991            tables: vec![Table::from("PCMember")],
992            fields: columns(&["PCMember.contactId"]),
993            join: vec![JoinClause {
994                operator: JoinOperator::Join,
995                right: JoinRightSide::Table(Table::from("PaperReview")),
996                constraint: JoinConstraint::On(join_cond),
997            }],
998            order: Some(OrderClause {
999                columns: vec![("contactId".into(), OrderType::OrderAscending)],
1000            }),
1001            ..Default::default()
1002        };
1003        assert_eq!(res.unwrap().1, expected);
1004
1005        // Same as above, but no brackets
1006        let qstring = "select PCMember.contactId \
1007                       from PCMember \
1008                       join PaperReview on PCMember.contactId=PaperReview.contactId \
1009                       order by contactId;";
1010        let res = selection(qstring.as_bytes());
1011        assert_eq!(res.unwrap().1, expected);
1012    }
1013
1014    #[test]
1015    fn multi_join() {
1016        // simplified from
1017        // "select max(conflictType), PaperReview.contactId as reviewer, PCMember.contactId as
1018        //  pcMember, ChairAssistant.contactId as assistant, Chair.contactId as chair,
1019        //  max(PaperReview.reviewNeedsSubmit) as reviewNeedsSubmit from ContactInfo
1020        //  left join PaperReview using (contactId) left join PaperConflict using (contactId)
1021        //  left join PCMember using (contactId) left join ChairAssistant using (contactId)
1022        //  left join Chair using (contactId) where ContactInfo.contactId=?
1023        //  group by ContactInfo.contactId;";
1024        let qstring = "select PCMember.contactId, ChairAssistant.contactId, \
1025                       Chair.contactId from ContactInfo left join PaperReview using (contactId) \
1026                       left join PaperConflict using (contactId) left join PCMember using \
1027                       (contactId) left join ChairAssistant using (contactId) left join Chair \
1028                       using (contactId) where ContactInfo.contactId=?;";
1029
1030        let res = selection(qstring.as_bytes());
1031        let ct = ConditionTree {
1032            left: Box::new(Base(Field(Column::from("ContactInfo.contactId")))),
1033            right: Box::new(Base(Literal(Literal::Placeholder))),
1034            operator: Operator::Equal,
1035        };
1036        let expected_where_cond = Some(ComparisonOp(ct));
1037        let mkjoin = |tbl: &str, col: &str| -> JoinClause {
1038            JoinClause {
1039                operator: JoinOperator::LeftJoin,
1040                right: JoinRightSide::Table(Table::from(tbl)),
1041                constraint: JoinConstraint::Using(vec![Column::from(col)]),
1042            }
1043        };
1044        assert_eq!(
1045            res.unwrap().1,
1046            SelectStatement {
1047                tables: vec![Table::from("ContactInfo")],
1048                fields: columns(&[
1049                    "PCMember.contactId",
1050                    "ChairAssistant.contactId",
1051                    "Chair.contactId"
1052                ]),
1053                join: vec![
1054                    mkjoin("PaperReview", "contactId"),
1055                    mkjoin("PaperConflict", "contactId"),
1056                    mkjoin("PCMember", "contactId"),
1057                    mkjoin("ChairAssistant", "contactId"),
1058                    mkjoin("Chair", "contactId"),
1059                ],
1060                where_clause: expected_where_cond,
1061                ..Default::default()
1062            }
1063        );
1064    }
1065
1066    #[test]
1067    fn nested_select() {
1068        let qstr = "SELECT ol_i_id FROM orders, order_line \
1069                    WHERE orders.o_c_id IN (SELECT o_c_id FROM orders, order_line \
1070                    WHERE orders.o_id = order_line.ol_o_id);";
1071
1072        let res = selection(qstr.as_bytes());
1073        let inner_where_clause = ComparisonOp(ConditionTree {
1074            left: Box::new(Base(Field(Column::from("orders.o_id")))),
1075            right: Box::new(Base(Field(Column::from("order_line.ol_o_id")))),
1076            operator: Operator::Equal,
1077        });
1078
1079        let inner_select = SelectStatement {
1080            tables: vec![Table::from("orders"), Table::from("order_line")],
1081            fields: columns(&["o_c_id"]),
1082            where_clause: Some(inner_where_clause),
1083            ..Default::default()
1084        };
1085
1086        let outer_where_clause = ComparisonOp(ConditionTree {
1087            left: Box::new(Base(Field(Column::from("orders.o_c_id")))),
1088            right: Box::new(Base(NestedSelect(Box::new(inner_select)))),
1089            operator: Operator::In,
1090        });
1091
1092        let outer_select = SelectStatement {
1093            tables: vec![Table::from("orders"), Table::from("order_line")],
1094            fields: columns(&["ol_i_id"]),
1095            where_clause: Some(outer_where_clause),
1096            ..Default::default()
1097        };
1098
1099        assert_eq!(res.unwrap().1, outer_select);
1100    }
1101
1102    #[test]
1103    fn recursive_nested_select() {
1104        let qstr = "SELECT ol_i_id FROM orders, order_line WHERE orders.o_c_id \
1105                    IN (SELECT o_c_id FROM orders, order_line \
1106                    WHERE orders.o_id = order_line.ol_o_id \
1107                    AND orders.o_id > (SELECT MAX(o_id) FROM orders));";
1108
1109        let res = selection(qstr.as_bytes());
1110
1111        let agg_expr = FunctionExpression::Max(FunctionArguments::Column(Column::from("o_id")));
1112        let recursive_select = SelectStatement {
1113            tables: vec![Table::from("orders")],
1114            fields: vec![FieldDefinitionExpression::Col(Column {
1115                name: String::from("max(o_id)"),
1116                alias: None,
1117                table: None,
1118                function: Some(Box::new(agg_expr)),
1119            })],
1120            ..Default::default()
1121        };
1122
1123        let cop1 = ComparisonOp(ConditionTree {
1124            left: Box::new(Base(Field(Column::from("orders.o_id")))),
1125            right: Box::new(Base(Field(Column::from("order_line.ol_o_id")))),
1126            operator: Operator::Equal,
1127        });
1128
1129        let cop2 = ComparisonOp(ConditionTree {
1130            left: Box::new(Base(Field(Column::from("orders.o_id")))),
1131            right: Box::new(Base(NestedSelect(Box::new(recursive_select)))),
1132            operator: Operator::Greater,
1133        });
1134
1135        let inner_where_clause = LogicalOp(ConditionTree {
1136            left: Box::new(cop1),
1137            right: Box::new(cop2),
1138            operator: Operator::And,
1139        });
1140
1141        let inner_select = SelectStatement {
1142            tables: vec![Table::from("orders"), Table::from("order_line")],
1143            fields: columns(&["o_c_id"]),
1144            where_clause: Some(inner_where_clause),
1145            ..Default::default()
1146        };
1147
1148        let outer_where_clause = ComparisonOp(ConditionTree {
1149            left: Box::new(Base(Field(Column::from("orders.o_c_id")))),
1150            right: Box::new(Base(NestedSelect(Box::new(inner_select)))),
1151            operator: Operator::In,
1152        });
1153
1154        let outer_select = SelectStatement {
1155            tables: vec![Table::from("orders"), Table::from("order_line")],
1156            fields: columns(&["ol_i_id"]),
1157            where_clause: Some(outer_where_clause),
1158            ..Default::default()
1159        };
1160
1161        assert_eq!(res.unwrap().1, outer_select);
1162    }
1163
1164    #[test]
1165    fn join_against_nested_select() {
1166        let t0 = b"(SELECT ol_i_id FROM order_line)";
1167        let t1 = b"(SELECT ol_i_id FROM order_line) AS ids";
1168
1169        assert!(join_rhs(t0).is_ok());
1170        assert!(join_rhs(t1).is_ok());
1171
1172        let t0 = b"JOIN (SELECT ol_i_id FROM order_line) ON (orders.o_id = ol_i_id)";
1173        let t1 = b"JOIN (SELECT ol_i_id FROM order_line) AS ids ON (orders.o_id = ids.ol_i_id)";
1174
1175        assert!(join_clause(t0).is_ok());
1176        assert!(join_clause(t1).is_ok());
1177
1178        let qstr_with_alias = "SELECT o_id, ol_i_id FROM orders JOIN \
1179                               (SELECT ol_i_id FROM order_line) AS ids \
1180                               ON (orders.o_id = ids.ol_i_id);";
1181        let res = selection(qstr_with_alias.as_bytes());
1182
1183        // N.B.: Don't alias the inner select to `inner`, which is, well, a SQL keyword!
1184        let inner_select = SelectStatement {
1185            tables: vec![Table::from("order_line")],
1186            fields: columns(&["ol_i_id"]),
1187            ..Default::default()
1188        };
1189
1190        let outer_select = SelectStatement {
1191            tables: vec![Table::from("orders")],
1192            fields: columns(&["o_id", "ol_i_id"]),
1193            join: vec![JoinClause {
1194                operator: JoinOperator::Join,
1195                right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())),
1196                constraint: JoinConstraint::On(ComparisonOp(ConditionTree {
1197                    operator: Operator::Equal,
1198                    left: Box::new(Base(Field(Column::from("orders.o_id")))),
1199                    right: Box::new(Base(Field(Column::from("ids.ol_i_id")))),
1200                })),
1201            }],
1202            ..Default::default()
1203        };
1204
1205        assert_eq!(res.unwrap().1, outer_select);
1206    }
1207
1208    #[test]
1209    fn project_arithmetic_expressions() {
1210        use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
1211
1212        let qstr = "SELECT MAX(o_id)-3333 FROM orders;";
1213        let res = selection(qstr.as_bytes());
1214
1215        let expected = SelectStatement {
1216            tables: vec![Table::from("orders")],
1217            fields: vec![FieldDefinitionExpression::Value(
1218                FieldValueExpression::Arithmetic(ArithmeticExpression {
1219                    alias: None,
1220                    op: ArithmeticOperator::Subtract,
1221                    left: ArithmeticBase::Column(Column {
1222                        name: String::from("max(o_id)"),
1223                        alias: None,
1224                        table: None,
1225                        function: Some(Box::new(FunctionExpression::Max(
1226                            FunctionArguments::Column("o_id".into()),
1227                        ))),
1228                    }),
1229                    right: ArithmeticBase::Scalar(3333.into()),
1230                }),
1231            )],
1232            ..Default::default()
1233        };
1234
1235        assert_eq!(res.unwrap().1, expected);
1236    }
1237
1238    #[test]
1239    fn project_arithmetic_expressions_with_aliases() {
1240        use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
1241
1242        let qstr = "SELECT max(o_id) * 2 as double_max FROM orders;";
1243        let res = selection(qstr.as_bytes());
1244
1245        let expected = SelectStatement {
1246            tables: vec![Table::from("orders")],
1247            fields: vec![FieldDefinitionExpression::Value(
1248                FieldValueExpression::Arithmetic(ArithmeticExpression {
1249                    alias: Some(String::from("double_max")),
1250                    op: ArithmeticOperator::Multiply,
1251                    left: ArithmeticBase::Column(Column {
1252                        name: String::from("max(o_id)"),
1253                        alias: None,
1254                        table: None,
1255                        function: Some(Box::new(FunctionExpression::Max(
1256                            FunctionArguments::Column("o_id".into()),
1257                        ))),
1258                    }),
1259                    right: ArithmeticBase::Scalar(2.into()),
1260                }),
1261            )],
1262            ..Default::default()
1263        };
1264
1265        assert_eq!(res.unwrap().1, expected);
1266    }
1267
1268    #[test]
1269    fn where_in_clause() {
1270        let qstr = "SELECT `auth_permission`.`content_type_id`, `auth_permission`.`codename`
1271                    FROM `auth_permission`
1272                    JOIN `django_content_type`
1273                      ON ( `auth_permission`.`content_type_id` = `django_content_type`.`id` )
1274                    WHERE `auth_permission`.`content_type_id` IN (0);";
1275        let res = selection(qstr.as_bytes());
1276
1277        let expected_where_clause = Some(ComparisonOp(ConditionTree {
1278            left: Box::new(Base(Field(Column::from("auth_permission.content_type_id")))),
1279            right: Box::new(Base(LiteralList(vec![0.into()]))),
1280            operator: Operator::In,
1281        }));
1282
1283        let expected = SelectStatement {
1284            tables: vec![Table::from("auth_permission")],
1285            fields: vec![
1286                FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")),
1287                FieldDefinitionExpression::Col(Column::from("auth_permission.codename")),
1288            ],
1289            join: vec![JoinClause {
1290                operator: JoinOperator::Join,
1291                right: JoinRightSide::Table(Table::from("django_content_type")),
1292                constraint: JoinConstraint::On(ComparisonOp(ConditionTree {
1293                    operator: Operator::Equal,
1294                    left: Box::new(Base(Field(Column::from("auth_permission.content_type_id")))),
1295                    right: Box::new(Base(Field(Column::from("django_content_type.id")))),
1296                })),
1297            }],
1298            where_clause: expected_where_clause,
1299            ..Default::default()
1300        };
1301
1302        assert_eq!(res.unwrap().1, expected);
1303    }
1304}