use std::fmt;
use std::str;

use nom::branch::alt;
use nom::bytes::complete::{tag, tag_no_case, take_while1};
use nom::character::complete::{multispace0, multispace1};
use nom::combinator::{map, not, opt};
use nom::multi::many0;
use nom::sequence::{delimited, preceded, terminated, tuple};
use nom::IResult;

use column::Column;
use common::FieldDefinitionExpression;
use common::{
    as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference,
    unsigned_number,
};
use condition::{condition_expr, ConditionExpression};
use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide};
use order::{order_clause, OrderClause};
use table::Table;
21
/// A parsed `GROUP BY` clause together with its optional `HAVING` condition.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct GroupByClause {
    /// Columns listed after `GROUP BY`.
    pub columns: Vec<Column>,
    /// Condition from a trailing `HAVING` clause, if one was present.
    pub having: Option<ConditionExpression>,
}
27
28impl fmt::Display for GroupByClause {
29 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30 write!(f, "GROUP BY ")?;
31 write!(
32 f,
33 "{}",
34 self.columns
35 .iter()
36 .map(|c| format!("{}", c))
37 .collect::<Vec<_>>()
38 .join(", ")
39 )?;
40 if let Some(ref having) = self.having {
41 write!(f, " HAVING {}", having)?;
42 }
43 Ok(())
44 }
45}
46
/// One join in a `SELECT` statement: `<operator> <right side> <constraint>`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct JoinClause {
    /// The join keyword(s), e.g. `JOIN` or `LEFT JOIN`.
    pub operator: JoinOperator,
    /// What is joined: a table, a parenthesised table list, a parenthesised nested
    /// join, or a parenthesised subquery with an optional alias.
    pub right: JoinRightSide,
    /// The `USING (...)` or `ON ...` constraint of the join.
    pub constraint: JoinConstraint,
}
53
54impl fmt::Display for JoinClause {
55 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56 write!(f, "{}", self.operator)?;
57 write!(f, " {}", self.right)?;
58 write!(f, " {}", self.constraint)?;
59 Ok(())
60 }
61}
62
/// A parsed `LIMIT n [OFFSET m]` clause.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct LimitClause {
    /// The value given after `LIMIT`.
    pub limit: u64,
    /// The value given after `OFFSET`; the parser stores 0 when no `OFFSET` was
    /// written, and `Display` omits the OFFSET part when this is 0.
    pub offset: u64,
}
68
69impl fmt::Display for LimitClause {
70 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
71 write!(f, "LIMIT {}", self.limit)?;
72 if self.offset > 0 {
73 write!(f, " OFFSET {}", self.offset)?;
74 }
75 Ok(())
76 }
77}
78
/// A parsed `SELECT` statement.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct SelectStatement {
    /// Tables listed after `FROM`.
    pub tables: Vec<Table>,
    /// Whether the `DISTINCT` keyword followed `SELECT`.
    pub distinct: bool,
    /// The selected field expressions.
    pub fields: Vec<FieldDefinitionExpression>,
    /// Join clauses that follow the `FROM` table list, in the order they appear.
    pub join: Vec<JoinClause>,
    /// Condition of the `WHERE` clause, if any.
    pub where_clause: Option<ConditionExpression>,
    /// The `GROUP BY` clause (with optional `HAVING`), if any.
    pub group_by: Option<GroupByClause>,
    /// The `ORDER BY` clause, if any.
    pub order: Option<OrderClause>,
    /// The `LIMIT` clause, if any.
    pub limit: Option<LimitClause>,
}
90
91impl fmt::Display for SelectStatement {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 write!(f, "SELECT ")?;
94 if self.distinct {
95 write!(f, "DISTINCT ")?;
96 }
97 write!(
98 f,
99 "{}",
100 self.fields
101 .iter()
102 .map(|field| format!("{}", field))
103 .collect::<Vec<_>>()
104 .join(", ")
105 )?;
106
107 if self.tables.len() > 0 {
108 write!(f, " FROM ")?;
109 write!(
110 f,
111 "{}",
112 self.tables
113 .iter()
114 .map(|table| format!("{}", table))
115 .collect::<Vec<_>>()
116 .join(", ")
117 )?;
118 }
119 for jc in &self.join {
120 write!(f, " {}", jc)?;
121 }
122 if let Some(ref where_clause) = self.where_clause {
123 write!(f, " WHERE ")?;
124 write!(f, "{}", where_clause)?;
125 }
126 if let Some(ref group_by) = self.group_by {
127 write!(f, " {}", group_by)?;
128 }
129 if let Some(ref order) = self.order {
130 write!(f, " {}", order)?;
131 }
132 if let Some(ref limit) = self.limit {
133 write!(f, " {}", limit)?;
134 }
135 Ok(())
136 }
137}
138
139fn having_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
140 let (remaining_input, (_, _, _, ce)) = tuple((
141 multispace0,
142 tag_no_case("having"),
143 multispace1,
144 condition_expr,
145 ))(i)?;
146
147 Ok((remaining_input, ce))
148}
149
150pub fn group_by_clause(i: &[u8]) -> IResult<&[u8], GroupByClause> {
152 let (remaining_input, (_, _, _, columns, having)) = tuple((
153 multispace0,
154 tag_no_case("group by"),
155 multispace1,
156 field_list,
157 opt(having_clause),
158 ))(i)?;
159
160 Ok((remaining_input, GroupByClause { columns, having }))
161}
162
163fn offset(i: &[u8]) -> IResult<&[u8], u64> {
164 let (remaining_input, (_, _, _, val)) = tuple((
165 multispace0,
166 tag_no_case("offset"),
167 multispace1,
168 unsigned_number,
169 ))(i)?;
170
171 Ok((remaining_input, val))
172}
173
174pub fn limit_clause(i: &[u8]) -> IResult<&[u8], LimitClause> {
176 let (remaining_input, (_, _, _, limit, opt_offset)) = tuple((
177 multispace0,
178 tag_no_case("limit"),
179 multispace1,
180 unsigned_number,
181 opt(offset),
182 ))(i)?;
183 let offset = match opt_offset {
184 None => 0,
185 Some(v) => v,
186 };
187
188 Ok((remaining_input, LimitClause { limit, offset }))
189}
190
191fn join_constraint(i: &[u8]) -> IResult<&[u8], JoinConstraint> {
192 let using_clause = map(
193 tuple((
194 tag_no_case("using"),
195 multispace1,
196 delimited(
197 terminated(tag("("), multispace0),
198 field_list,
199 preceded(multispace0, tag(")")),
200 ),
201 )),
202 |t| JoinConstraint::Using(t.2),
203 );
204 let on_condition = alt((
205 delimited(
206 terminated(tag("("), multispace0),
207 condition_expr,
208 preceded(multispace0, tag(")")),
209 ),
210 condition_expr,
211 ));
212 let on_clause = map(tuple((tag_no_case("on"), multispace1, on_condition)), |t| {
213 JoinConstraint::On(t.2)
214 });
215
216 alt((using_clause, on_clause))(i)
217}
218
219fn join_clause(i: &[u8]) -> IResult<&[u8], JoinClause> {
221 let (remaining_input, (_, _natural, operator, _, right, _, constraint)) = tuple((
222 multispace0,
223 opt(terminated(tag_no_case("natural"), multispace1)),
224 join_operator,
225 multispace1,
226 join_rhs,
227 multispace1,
228 join_constraint,
229 ))(i)?;
230
231 Ok((
232 remaining_input,
233 JoinClause {
234 operator,
235 right,
236 constraint,
237 },
238 ))
239}
240
241fn join_rhs(i: &[u8]) -> IResult<&[u8], JoinRightSide> {
242 let nested_select = map(
243 tuple((
244 delimited(tag("("), nested_selection, tag(")")),
245 opt(as_alias),
246 )),
247 |t| JoinRightSide::NestedSelect(Box::new(t.0), t.1.map(String::from)),
248 );
249 let nested_join = map(delimited(tag("("), join_clause, tag(")")), |nj| {
250 JoinRightSide::NestedJoin(Box::new(nj))
251 });
252 let table = map(table_reference, |t| JoinRightSide::Table(t));
253 let tables = map(delimited(tag("("), table_list, tag(")")), |tables| {
254 JoinRightSide::Tables(tables)
255 });
256 alt((nested_select, nested_join, table, tables))(i)
257}
258
259pub fn where_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
261 let (remaining_input, (_, _, _, where_condition)) = tuple((
262 multispace0,
263 tag_no_case("where"),
264 multispace1,
265 condition_expr,
266 ))(i)?;
267
268 Ok((remaining_input, where_condition))
269}
270
271pub fn selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
273 terminated(nested_selection, statement_terminator)(i)
274}
275
276pub fn nested_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
277 let (
278 remaining_input,
279 (_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit),
280 ) = tuple((
281 tag_no_case("select"),
282 multispace1,
283 opt(tag_no_case("distinct")),
284 multispace0,
285 field_definition_expr,
286 delimited(multispace0, tag_no_case("from"), multispace0),
287 table_list,
288 many0(join_clause),
289 opt(where_clause),
290 opt(group_by_clause),
291 opt(order_clause),
292 opt(limit_clause),
293 ))(i)?;
294 Ok((
295 remaining_input,
296 SelectStatement {
297 tables,
298 distinct: distinct.is_some(),
299 fields,
300 join,
301 where_clause,
302 group_by,
303 order,
304 limit,
305 },
306 ))
307}
308
309#[cfg(test)]
310mod tests {
311 use super::*;
312 use case::{CaseWhenExpression, ColumnOrLiteral};
313 use column::{Column, FunctionArguments, FunctionExpression};
314 use common::{FieldDefinitionExpression, FieldValueExpression, Literal, Operator};
315 use condition::ConditionBase::*;
316 use condition::ConditionExpression::*;
317 use condition::ConditionTree;
318 use order::OrderType;
319 use table::Table;
320
321 fn columns(cols: &[&str]) -> Vec<FieldDefinitionExpression> {
322 cols.iter()
323 .map(|c| FieldDefinitionExpression::Col(Column::from(*c)))
324 .collect()
325 }
326
327 #[test]
328 fn simple_select() {
329 let qstring = "SELECT id, name FROM users;";
330
331 let res = selection(qstring.as_bytes());
332 assert_eq!(
333 res.unwrap().1,
334 SelectStatement {
335 tables: vec![Table::from("users")],
336 fields: columns(&["id", "name"]),
337 ..Default::default()
338 }
339 );
340 }
341
342 #[test]
343 fn more_involved_select() {
344 let qstring = "SELECT users.id, users.name FROM users;";
345
346 let res = selection(qstring.as_bytes());
347 assert_eq!(
348 res.unwrap().1,
349 SelectStatement {
350 tables: vec![Table::from("users")],
351 fields: columns(&["users.id", "users.name"]),
352 ..Default::default()
353 }
354 );
355 }
356
357 #[test]
358 fn select_literals() {
359 use common::Literal;
360
361 let qstring = "SELECT NULL, 1, \"foo\", CURRENT_TIME FROM users;";
362 let res = selection(qstring.as_bytes());
366 assert_eq!(
367 res.unwrap().1,
368 SelectStatement {
369 tables: vec![Table::from("users")],
370 fields: vec![
371 FieldDefinitionExpression::Value(FieldValueExpression::Literal(
372 Literal::Null.into(),
373 )),
374 FieldDefinitionExpression::Value(FieldValueExpression::Literal(
375 Literal::Integer(1).into(),
376 )),
377 FieldDefinitionExpression::Value(FieldValueExpression::Literal(
378 Literal::String("foo".to_owned()).into(),
379 )),
380 FieldDefinitionExpression::Value(FieldValueExpression::Literal(
381 Literal::CurrentTime.into(),
382 )),
383 ],
384 ..Default::default()
385 }
386 );
387 }
388
389 #[test]
390 fn select_all() {
391 let qstring = "SELECT * FROM users;";
392
393 let res = selection(qstring.as_bytes());
394 assert_eq!(
395 res.unwrap().1,
396 SelectStatement {
397 tables: vec![Table::from("users")],
398 fields: vec![FieldDefinitionExpression::All],
399 ..Default::default()
400 }
401 );
402 }
403
404 #[test]
405 fn select_all_in_table() {
406 let qstring = "SELECT users.* FROM users, votes;";
407
408 let res = selection(qstring.as_bytes());
409 assert_eq!(
410 res.unwrap().1,
411 SelectStatement {
412 tables: vec![Table::from("users"), Table::from("votes")],
413 fields: vec![FieldDefinitionExpression::AllInTable(String::from("users"))],
414 ..Default::default()
415 }
416 );
417 }
418
419 #[test]
420 fn spaces_optional() {
421 let qstring = "SELECT id,name FROM users;";
422
423 let res = selection(qstring.as_bytes());
424 assert_eq!(
425 res.unwrap().1,
426 SelectStatement {
427 tables: vec![Table::from("users")],
428 fields: columns(&["id", "name"]),
429 ..Default::default()
430 }
431 );
432 }
433
434 #[test]
435 fn case_sensitivity() {
436 let qstring_lc = "select id, name from users;";
437 let qstring_uc = "SELECT id, name FROM users;";
438
439 assert_eq!(
440 selection(qstring_lc.as_bytes()).unwrap(),
441 selection(qstring_uc.as_bytes()).unwrap()
442 );
443 }
444
445 #[test]
446 fn termination() {
447 let qstring_sem = "select id, name from users;";
448 let qstring_nosem = "select id, name from users";
449 let qstring_linebreak = "select id, name from users\n";
450
451 let r1 = selection(qstring_sem.as_bytes()).unwrap();
452 let r2 = selection(qstring_nosem.as_bytes()).unwrap();
453 let r3 = selection(qstring_linebreak.as_bytes()).unwrap();
454 assert_eq!(r1, r2);
455 assert_eq!(r2, r3);
456 }
457
458 #[test]
459 fn where_clause() {
460 let qstring = "select * from ContactInfo where email=?;";
461
462 let res = selection(qstring.as_bytes());
463
464 let expected_left = Base(Field(Column::from("email")));
465 let expected_where_cond = Some(ComparisonOp(ConditionTree {
466 left: Box::new(expected_left),
467 right: Box::new(Base(Literal(Literal::Placeholder))),
468 operator: Operator::Equal,
469 }));
470 assert_eq!(
471 res.unwrap().1,
472 SelectStatement {
473 tables: vec![Table::from("ContactInfo")],
474 fields: vec![FieldDefinitionExpression::All],
475 where_clause: expected_where_cond,
476 ..Default::default()
477 }
478 );
479 }
480
481 #[test]
482 fn limit_clause() {
483 let qstring1 = "select * from users limit 10\n";
484 let qstring2 = "select * from users limit 10 offset 10\n";
485
486 let expected_lim1 = LimitClause {
487 limit: 10,
488 offset: 0,
489 };
490 let expected_lim2 = LimitClause {
491 limit: 10,
492 offset: 10,
493 };
494
495 let res1 = selection(qstring1.as_bytes());
496 let res2 = selection(qstring2.as_bytes());
497 assert_eq!(res1.unwrap().1.limit, Some(expected_lim1));
498 assert_eq!(res2.unwrap().1.limit, Some(expected_lim2));
499 }
500
501 #[test]
502 fn table_alias() {
503 let qstring1 = "select * from PaperTag as t;";
504 let res1 = selection(qstring1.as_bytes());
507 assert_eq!(
508 res1.clone().unwrap().1,
509 SelectStatement {
510 tables: vec![Table {
511 name: String::from("PaperTag"),
512 alias: Some(String::from("t")),
513 },],
514 fields: vec![FieldDefinitionExpression::All],
515 ..Default::default()
516 }
517 );
518 }
521
522 #[test]
523 fn column_alias() {
524 let qstring1 = "select name as TagName from PaperTag;";
525 let qstring2 = "select PaperTag.name as TagName from PaperTag;";
526
527 let res1 = selection(qstring1.as_bytes());
528 assert_eq!(
529 res1.clone().unwrap().1,
530 SelectStatement {
531 tables: vec![Table::from("PaperTag")],
532 fields: vec![FieldDefinitionExpression::Col(Column {
533 name: String::from("name"),
534 alias: Some(String::from("TagName")),
535 table: None,
536 function: None,
537 }),],
538 ..Default::default()
539 }
540 );
541 let res2 = selection(qstring2.as_bytes());
542 assert_eq!(
543 res2.clone().unwrap().1,
544 SelectStatement {
545 tables: vec![Table::from("PaperTag")],
546 fields: vec![FieldDefinitionExpression::Col(Column {
547 name: String::from("name"),
548 alias: Some(String::from("TagName")),
549 table: Some(String::from("PaperTag")),
550 function: None,
551 }),],
552 ..Default::default()
553 }
554 );
555 }
556
557 #[test]
558 fn column_alias_no_as() {
559 let qstring1 = "select name TagName from PaperTag;";
560 let qstring2 = "select PaperTag.name TagName from PaperTag;";
561
562 let res1 = selection(qstring1.as_bytes());
563 assert_eq!(
564 res1.clone().unwrap().1,
565 SelectStatement {
566 tables: vec![Table::from("PaperTag")],
567 fields: vec![FieldDefinitionExpression::Col(Column {
568 name: String::from("name"),
569 alias: Some(String::from("TagName")),
570 table: None,
571 function: None,
572 }),],
573 ..Default::default()
574 }
575 );
576 let res2 = selection(qstring2.as_bytes());
577 assert_eq!(
578 res2.clone().unwrap().1,
579 SelectStatement {
580 tables: vec![Table::from("PaperTag")],
581 fields: vec![FieldDefinitionExpression::Col(Column {
582 name: String::from("name"),
583 alias: Some(String::from("TagName")),
584 table: Some(String::from("PaperTag")),
585 function: None,
586 }),],
587 ..Default::default()
588 }
589 );
590 }
591
592 #[test]
593 fn distinct() {
594 let qstring = "select distinct tag from PaperTag where paperId=?;";
595
596 let res = selection(qstring.as_bytes());
597 let expected_left = Base(Field(Column::from("paperId")));
598 let expected_where_cond = Some(ComparisonOp(ConditionTree {
599 left: Box::new(expected_left),
600 right: Box::new(Base(Literal(Literal::Placeholder))),
601 operator: Operator::Equal,
602 }));
603 assert_eq!(
604 res.unwrap().1,
605 SelectStatement {
606 tables: vec![Table::from("PaperTag")],
607 distinct: true,
608 fields: columns(&["tag"]),
609 where_clause: expected_where_cond,
610 ..Default::default()
611 }
612 );
613 }
614
615 #[test]
616 fn simple_condition_expr() {
617 let qstring = "select infoJson from PaperStorage where paperId=? and paperStorageId=?;";
618
619 let res = selection(qstring.as_bytes());
620
621 let left_ct = ConditionTree {
622 left: Box::new(Base(Field(Column::from("paperId")))),
623 right: Box::new(Base(Literal(Literal::Placeholder))),
624 operator: Operator::Equal,
625 };
626 let left_comp = Box::new(ComparisonOp(left_ct));
627 let right_ct = ConditionTree {
628 left: Box::new(Base(Field(Column::from("paperStorageId")))),
629 right: Box::new(Base(Literal(Literal::Placeholder))),
630 operator: Operator::Equal,
631 };
632 let right_comp = Box::new(ComparisonOp(right_ct));
633 let expected_where_cond = Some(LogicalOp(ConditionTree {
634 left: left_comp,
635 right: right_comp,
636 operator: Operator::And,
637 }));
638 assert_eq!(
639 res.unwrap().1,
640 SelectStatement {
641 tables: vec![Table::from("PaperStorage")],
642 fields: columns(&["infoJson"]),
643 where_clause: expected_where_cond,
644 ..Default::default()
645 }
646 );
647 }
648
649 #[test]
650 fn where_and_limit_clauses() {
651 let qstring = "select * from users where id = ? limit 10\n";
652 let res = selection(qstring.as_bytes());
653
654 let expected_lim = Some(LimitClause {
655 limit: 10,
656 offset: 0,
657 });
658 let ct = ConditionTree {
659 left: Box::new(Base(Field(Column::from("id")))),
660 right: Box::new(Base(Literal(Literal::Placeholder))),
661 operator: Operator::Equal,
662 };
663 let expected_where_cond = Some(ComparisonOp(ct));
664
665 assert_eq!(
666 res.unwrap().1,
667 SelectStatement {
668 tables: vec![Table::from("users")],
669 fields: vec![FieldDefinitionExpression::All],
670 where_clause: expected_where_cond,
671 limit: expected_lim,
672 ..Default::default()
673 }
674 );
675 }
676
677 #[test]
678 fn aggregation_column() {
679 let qstring = "SELECT max(addr_id) FROM address;";
680
681 let res = selection(qstring.as_bytes());
682 let agg_expr = FunctionExpression::Max(FunctionArguments::Column(Column::from("addr_id")));
683 assert_eq!(
684 res.unwrap().1,
685 SelectStatement {
686 tables: vec![Table::from("address")],
687 fields: vec![FieldDefinitionExpression::Col(Column {
688 name: String::from("max(addr_id)"),
689 alias: None,
690 table: None,
691 function: Some(Box::new(agg_expr)),
692 }),],
693 ..Default::default()
694 }
695 );
696 }
697
698 #[test]
699 fn aggregation_column_with_alias() {
700 let qstring = "SELECT max(addr_id) AS max_addr FROM address;";
701
702 let res = selection(qstring.as_bytes());
703 let agg_expr = FunctionExpression::Max(FunctionArguments::Column(Column::from("addr_id")));
704 let expected_stmt = SelectStatement {
705 tables: vec![Table::from("address")],
706 fields: vec![FieldDefinitionExpression::Col(Column {
707 name: String::from("max_addr"),
708 alias: Some(String::from("max_addr")),
709 table: None,
710 function: Some(Box::new(agg_expr)),
711 })],
712 ..Default::default()
713 };
714 assert_eq!(res.unwrap().1, expected_stmt);
715 }
716
717 #[test]
718 fn count_all() {
719 let qstring = "SELECT COUNT(*) FROM votes GROUP BY aid;";
720
721 let res = selection(qstring.as_bytes());
722 let agg_expr = FunctionExpression::CountStar;
723 let expected_stmt = SelectStatement {
724 tables: vec![Table::from("votes")],
725 fields: vec![FieldDefinitionExpression::Col(Column {
726 name: String::from("count(*)"),
727 alias: None,
728 table: None,
729 function: Some(Box::new(agg_expr)),
730 })],
731 group_by: Some(GroupByClause {
732 columns: vec![Column::from("aid")],
733 having: None,
734 }),
735 ..Default::default()
736 };
737 assert_eq!(res.unwrap().1, expected_stmt);
738 }
739
740 #[test]
741 fn count_distinct() {
742 let qstring = "SELECT COUNT(DISTINCT vote_id) FROM votes GROUP BY aid;";
743
744 let res = selection(qstring.as_bytes());
745 let agg_expr =
746 FunctionExpression::Count(FunctionArguments::Column(Column::from("vote_id")), true);
747 let expected_stmt = SelectStatement {
748 tables: vec![Table::from("votes")],
749 fields: vec![FieldDefinitionExpression::Col(Column {
750 name: String::from("count(distinct vote_id)"),
751 alias: None,
752 table: None,
753 function: Some(Box::new(agg_expr)),
754 })],
755 group_by: Some(GroupByClause {
756 columns: vec![Column::from("aid")],
757 having: None,
758 }),
759 ..Default::default()
760 };
761 assert_eq!(res.unwrap().1, expected_stmt);
762 }
763
764 #[test]
765 fn count_filter() {
766 let qstring =
767 "SELECT COUNT(CASE WHEN vote_id > 10 THEN vote_id END) FROM votes GROUP BY aid;";
768 let res = selection(qstring.as_bytes());
769
770 let filter_cond = ComparisonOp(ConditionTree {
771 left: Box::new(Base(Field(Column::from("vote_id")))),
772 right: Box::new(Base(Literal(Literal::Integer(10.into())))),
773 operator: Operator::Greater,
774 });
775 let agg_expr = FunctionExpression::Count(
776 FunctionArguments::Conditional(CaseWhenExpression {
777 then_expr: ColumnOrLiteral::Column(Column::from("vote_id")),
778 else_expr: None,
779 condition: filter_cond,
780 }),
781 false,
782 );
783 let expected_stmt = SelectStatement {
784 tables: vec![Table::from("votes")],
785 fields: vec![FieldDefinitionExpression::Col(Column {
786 name: format!("{}", agg_expr),
787 alias: None,
788 table: None,
789 function: Some(Box::new(agg_expr)),
790 })],
791 group_by: Some(GroupByClause {
792 columns: vec![Column::from("aid")],
793 having: None,
794 }),
795 ..Default::default()
796 };
797 assert_eq!(res.unwrap().1, expected_stmt);
798 }
799
800 #[test]
801 fn sum_filter() {
802 let qstring = "SELECT SUM(CASE WHEN sign = 1 THEN vote_id END) FROM votes GROUP BY aid;";
803
804 let res = selection(qstring.as_bytes());
805
806 let filter_cond = ComparisonOp(ConditionTree {
807 left: Box::new(Base(Field(Column::from("sign")))),
808 right: Box::new(Base(Literal(Literal::Integer(1.into())))),
809 operator: Operator::Equal,
810 });
811 let agg_expr = FunctionExpression::Sum(
812 FunctionArguments::Conditional(CaseWhenExpression {
813 then_expr: ColumnOrLiteral::Column(Column::from("vote_id")),
814 else_expr: None,
815 condition: filter_cond,
816 }),
817 false,
818 );
819 let expected_stmt = SelectStatement {
820 tables: vec![Table::from("votes")],
821 fields: vec![FieldDefinitionExpression::Col(Column {
822 name: format!("{}", agg_expr),
823 alias: None,
824 table: None,
825 function: Some(Box::new(agg_expr)),
826 })],
827 group_by: Some(GroupByClause {
828 columns: vec![Column::from("aid")],
829 having: None,
830 }),
831 ..Default::default()
832 };
833 assert_eq!(res.unwrap().1, expected_stmt);
834 }
835
836 #[test]
837 fn sum_filter_else_literal() {
838 let qstring =
839 "SELECT SUM(CASE WHEN sign = 1 THEN vote_id ELSE 6 END) FROM votes GROUP BY aid;";
840
841 let res = selection(qstring.as_bytes());
842
843 let filter_cond = ComparisonOp(ConditionTree {
844 left: Box::new(Base(Field(Column::from("sign")))),
845 right: Box::new(Base(Literal(Literal::Integer(1.into())))),
846 operator: Operator::Equal,
847 });
848 let agg_expr = FunctionExpression::Sum(
849 FunctionArguments::Conditional(CaseWhenExpression {
850 then_expr: ColumnOrLiteral::Column(Column::from("vote_id")),
851 else_expr: Some(ColumnOrLiteral::Literal(Literal::Integer(6))),
852 condition: filter_cond,
853 }),
854 false,
855 );
856 let expected_stmt = SelectStatement {
857 tables: vec![Table::from("votes")],
858 fields: vec![FieldDefinitionExpression::Col(Column {
859 name: format!("{}", agg_expr),
860 alias: None,
861 table: None,
862 function: Some(Box::new(agg_expr)),
863 })],
864 group_by: Some(GroupByClause {
865 columns: vec![Column::from("aid")],
866 having: None,
867 }),
868 ..Default::default()
869 };
870 assert_eq!(res.unwrap().1, expected_stmt);
871 }
872
873 #[test]
874 fn count_filter_lobsters() {
875 let qstring = "SELECT
876 COUNT(CASE WHEN votes.story_id IS NULL AND votes.vote = 0 THEN votes.vote END) as votes
877 FROM votes
878 GROUP BY votes.comment_id;";
879
880 let res = selection(qstring.as_bytes());
881
882 let filter_cond = LogicalOp(ConditionTree {
883 left: Box::new(ComparisonOp(ConditionTree {
884 left: Box::new(Base(Field(Column::from("votes.story_id")))),
885 right: Box::new(Base(Literal(Literal::Null))),
886 operator: Operator::Equal,
887 })),
888 right: Box::new(ComparisonOp(ConditionTree {
889 left: Box::new(Base(Field(Column::from("votes.vote")))),
890 right: Box::new(Base(Literal(Literal::Integer(0)))),
891 operator: Operator::Equal,
892 })),
893 operator: Operator::And,
894 });
895 let agg_expr = FunctionExpression::Count(
896 FunctionArguments::Conditional(CaseWhenExpression {
897 then_expr: ColumnOrLiteral::Column(Column::from("votes.vote")),
898 else_expr: None,
899 condition: filter_cond,
900 }),
901 false,
902 );
903 let expected_stmt = SelectStatement {
904 tables: vec![Table::from("votes")],
905 fields: vec![FieldDefinitionExpression::Col(Column {
906 name: String::from("votes"),
907 alias: Some(String::from("votes")),
908 table: None,
909 function: Some(Box::new(agg_expr)),
910 })],
911 group_by: Some(GroupByClause {
912 columns: vec![Column::from("votes.comment_id")],
913 having: None,
914 }),
915 ..Default::default()
916 };
917 assert_eq!(res.unwrap().1, expected_stmt);
918 }
919
920 #[test]
921 fn moderately_complex_selection() {
922 let qstring = "SELECT * FROM item, author WHERE item.i_a_id = author.a_id AND \
923 item.i_subject = ? ORDER BY item.i_title limit 50;";
924
925 let res = selection(qstring.as_bytes());
926 let expected_where_cond = Some(LogicalOp(ConditionTree {
927 left: Box::new(ComparisonOp(ConditionTree {
928 left: Box::new(Base(Field(Column::from("item.i_a_id")))),
929 right: Box::new(Base(Field(Column::from("author.a_id")))),
930 operator: Operator::Equal,
931 })),
932 right: Box::new(ComparisonOp(ConditionTree {
933 left: Box::new(Base(Field(Column::from("item.i_subject")))),
934 right: Box::new(Base(Literal(Literal::Placeholder))),
935 operator: Operator::Equal,
936 })),
937 operator: Operator::And,
938 }));
939 assert_eq!(
940 res.unwrap().1,
941 SelectStatement {
942 tables: vec![Table::from("item"), Table::from("author")],
943 fields: vec![FieldDefinitionExpression::All],
944 where_clause: expected_where_cond,
945 order: Some(OrderClause {
946 columns: vec![("item.i_title".into(), OrderType::OrderAscending)],
947 }),
948 limit: Some(LimitClause {
949 limit: 50,
950 offset: 0,
951 }),
952 ..Default::default()
953 }
954 );
955 }
956
957 #[test]
958 fn simple_joins() {
959 let qstring = "select paperId from PaperConflict join PCMember using (contactId);";
960
961 let res = selection(qstring.as_bytes());
962 let expected_stmt = SelectStatement {
963 tables: vec![Table::from("PaperConflict")],
964 fields: columns(&["paperId"]),
965 join: vec![JoinClause {
966 operator: JoinOperator::Join,
967 right: JoinRightSide::Table(Table::from("PCMember")),
968 constraint: JoinConstraint::Using(vec![Column::from("contactId")]),
969 }],
970 ..Default::default()
971 };
972 assert_eq!(res.unwrap().1, expected_stmt);
973
974 let qstring = "select PCMember.contactId \
979 from PCMember \
980 join PaperReview on (PCMember.contactId=PaperReview.contactId) \
981 order by contactId;";
982
983 let res = selection(qstring.as_bytes());
984 let ct = ConditionTree {
985 left: Box::new(Base(Field(Column::from("PCMember.contactId")))),
986 right: Box::new(Base(Field(Column::from("PaperReview.contactId")))),
987 operator: Operator::Equal,
988 };
989 let join_cond = ConditionExpression::ComparisonOp(ct);
990 let expected = SelectStatement {
991 tables: vec![Table::from("PCMember")],
992 fields: columns(&["PCMember.contactId"]),
993 join: vec![JoinClause {
994 operator: JoinOperator::Join,
995 right: JoinRightSide::Table(Table::from("PaperReview")),
996 constraint: JoinConstraint::On(join_cond),
997 }],
998 order: Some(OrderClause {
999 columns: vec![("contactId".into(), OrderType::OrderAscending)],
1000 }),
1001 ..Default::default()
1002 };
1003 assert_eq!(res.unwrap().1, expected);
1004
1005 let qstring = "select PCMember.contactId \
1007 from PCMember \
1008 join PaperReview on PCMember.contactId=PaperReview.contactId \
1009 order by contactId;";
1010 let res = selection(qstring.as_bytes());
1011 assert_eq!(res.unwrap().1, expected);
1012 }
1013
1014 #[test]
1015 fn multi_join() {
1016 let qstring = "select PCMember.contactId, ChairAssistant.contactId, \
1025 Chair.contactId from ContactInfo left join PaperReview using (contactId) \
1026 left join PaperConflict using (contactId) left join PCMember using \
1027 (contactId) left join ChairAssistant using (contactId) left join Chair \
1028 using (contactId) where ContactInfo.contactId=?;";
1029
1030 let res = selection(qstring.as_bytes());
1031 let ct = ConditionTree {
1032 left: Box::new(Base(Field(Column::from("ContactInfo.contactId")))),
1033 right: Box::new(Base(Literal(Literal::Placeholder))),
1034 operator: Operator::Equal,
1035 };
1036 let expected_where_cond = Some(ComparisonOp(ct));
1037 let mkjoin = |tbl: &str, col: &str| -> JoinClause {
1038 JoinClause {
1039 operator: JoinOperator::LeftJoin,
1040 right: JoinRightSide::Table(Table::from(tbl)),
1041 constraint: JoinConstraint::Using(vec![Column::from(col)]),
1042 }
1043 };
1044 assert_eq!(
1045 res.unwrap().1,
1046 SelectStatement {
1047 tables: vec![Table::from("ContactInfo")],
1048 fields: columns(&[
1049 "PCMember.contactId",
1050 "ChairAssistant.contactId",
1051 "Chair.contactId"
1052 ]),
1053 join: vec![
1054 mkjoin("PaperReview", "contactId"),
1055 mkjoin("PaperConflict", "contactId"),
1056 mkjoin("PCMember", "contactId"),
1057 mkjoin("ChairAssistant", "contactId"),
1058 mkjoin("Chair", "contactId"),
1059 ],
1060 where_clause: expected_where_cond,
1061 ..Default::default()
1062 }
1063 );
1064 }
1065
1066 #[test]
1067 fn nested_select() {
1068 let qstr = "SELECT ol_i_id FROM orders, order_line \
1069 WHERE orders.o_c_id IN (SELECT o_c_id FROM orders, order_line \
1070 WHERE orders.o_id = order_line.ol_o_id);";
1071
1072 let res = selection(qstr.as_bytes());
1073 let inner_where_clause = ComparisonOp(ConditionTree {
1074 left: Box::new(Base(Field(Column::from("orders.o_id")))),
1075 right: Box::new(Base(Field(Column::from("order_line.ol_o_id")))),
1076 operator: Operator::Equal,
1077 });
1078
1079 let inner_select = SelectStatement {
1080 tables: vec![Table::from("orders"), Table::from("order_line")],
1081 fields: columns(&["o_c_id"]),
1082 where_clause: Some(inner_where_clause),
1083 ..Default::default()
1084 };
1085
1086 let outer_where_clause = ComparisonOp(ConditionTree {
1087 left: Box::new(Base(Field(Column::from("orders.o_c_id")))),
1088 right: Box::new(Base(NestedSelect(Box::new(inner_select)))),
1089 operator: Operator::In,
1090 });
1091
1092 let outer_select = SelectStatement {
1093 tables: vec![Table::from("orders"), Table::from("order_line")],
1094 fields: columns(&["ol_i_id"]),
1095 where_clause: Some(outer_where_clause),
1096 ..Default::default()
1097 };
1098
1099 assert_eq!(res.unwrap().1, outer_select);
1100 }
1101
1102 #[test]
1103 fn recursive_nested_select() {
1104 let qstr = "SELECT ol_i_id FROM orders, order_line WHERE orders.o_c_id \
1105 IN (SELECT o_c_id FROM orders, order_line \
1106 WHERE orders.o_id = order_line.ol_o_id \
1107 AND orders.o_id > (SELECT MAX(o_id) FROM orders));";
1108
1109 let res = selection(qstr.as_bytes());
1110
1111 let agg_expr = FunctionExpression::Max(FunctionArguments::Column(Column::from("o_id")));
1112 let recursive_select = SelectStatement {
1113 tables: vec![Table::from("orders")],
1114 fields: vec![FieldDefinitionExpression::Col(Column {
1115 name: String::from("max(o_id)"),
1116 alias: None,
1117 table: None,
1118 function: Some(Box::new(agg_expr)),
1119 })],
1120 ..Default::default()
1121 };
1122
1123 let cop1 = ComparisonOp(ConditionTree {
1124 left: Box::new(Base(Field(Column::from("orders.o_id")))),
1125 right: Box::new(Base(Field(Column::from("order_line.ol_o_id")))),
1126 operator: Operator::Equal,
1127 });
1128
1129 let cop2 = ComparisonOp(ConditionTree {
1130 left: Box::new(Base(Field(Column::from("orders.o_id")))),
1131 right: Box::new(Base(NestedSelect(Box::new(recursive_select)))),
1132 operator: Operator::Greater,
1133 });
1134
1135 let inner_where_clause = LogicalOp(ConditionTree {
1136 left: Box::new(cop1),
1137 right: Box::new(cop2),
1138 operator: Operator::And,
1139 });
1140
1141 let inner_select = SelectStatement {
1142 tables: vec![Table::from("orders"), Table::from("order_line")],
1143 fields: columns(&["o_c_id"]),
1144 where_clause: Some(inner_where_clause),
1145 ..Default::default()
1146 };
1147
1148 let outer_where_clause = ComparisonOp(ConditionTree {
1149 left: Box::new(Base(Field(Column::from("orders.o_c_id")))),
1150 right: Box::new(Base(NestedSelect(Box::new(inner_select)))),
1151 operator: Operator::In,
1152 });
1153
1154 let outer_select = SelectStatement {
1155 tables: vec![Table::from("orders"), Table::from("order_line")],
1156 fields: columns(&["ol_i_id"]),
1157 where_clause: Some(outer_where_clause),
1158 ..Default::default()
1159 };
1160
1161 assert_eq!(res.unwrap().1, outer_select);
1162 }
1163
1164 #[test]
1165 fn join_against_nested_select() {
1166 let t0 = b"(SELECT ol_i_id FROM order_line)";
1167 let t1 = b"(SELECT ol_i_id FROM order_line) AS ids";
1168
1169 assert!(join_rhs(t0).is_ok());
1170 assert!(join_rhs(t1).is_ok());
1171
1172 let t0 = b"JOIN (SELECT ol_i_id FROM order_line) ON (orders.o_id = ol_i_id)";
1173 let t1 = b"JOIN (SELECT ol_i_id FROM order_line) AS ids ON (orders.o_id = ids.ol_i_id)";
1174
1175 assert!(join_clause(t0).is_ok());
1176 assert!(join_clause(t1).is_ok());
1177
1178 let qstr_with_alias = "SELECT o_id, ol_i_id FROM orders JOIN \
1179 (SELECT ol_i_id FROM order_line) AS ids \
1180 ON (orders.o_id = ids.ol_i_id);";
1181 let res = selection(qstr_with_alias.as_bytes());
1182
1183 let inner_select = SelectStatement {
1185 tables: vec![Table::from("order_line")],
1186 fields: columns(&["ol_i_id"]),
1187 ..Default::default()
1188 };
1189
1190 let outer_select = SelectStatement {
1191 tables: vec![Table::from("orders")],
1192 fields: columns(&["o_id", "ol_i_id"]),
1193 join: vec![JoinClause {
1194 operator: JoinOperator::Join,
1195 right: JoinRightSide::NestedSelect(Box::new(inner_select), Some("ids".into())),
1196 constraint: JoinConstraint::On(ComparisonOp(ConditionTree {
1197 operator: Operator::Equal,
1198 left: Box::new(Base(Field(Column::from("orders.o_id")))),
1199 right: Box::new(Base(Field(Column::from("ids.ol_i_id")))),
1200 })),
1201 }],
1202 ..Default::default()
1203 };
1204
1205 assert_eq!(res.unwrap().1, outer_select);
1206 }
1207
1208 #[test]
1209 fn project_arithmetic_expressions() {
1210 use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
1211
1212 let qstr = "SELECT MAX(o_id)-3333 FROM orders;";
1213 let res = selection(qstr.as_bytes());
1214
1215 let expected = SelectStatement {
1216 tables: vec![Table::from("orders")],
1217 fields: vec![FieldDefinitionExpression::Value(
1218 FieldValueExpression::Arithmetic(ArithmeticExpression {
1219 alias: None,
1220 op: ArithmeticOperator::Subtract,
1221 left: ArithmeticBase::Column(Column {
1222 name: String::from("max(o_id)"),
1223 alias: None,
1224 table: None,
1225 function: Some(Box::new(FunctionExpression::Max(
1226 FunctionArguments::Column("o_id".into()),
1227 ))),
1228 }),
1229 right: ArithmeticBase::Scalar(3333.into()),
1230 }),
1231 )],
1232 ..Default::default()
1233 };
1234
1235 assert_eq!(res.unwrap().1, expected);
1236 }
1237
1238 #[test]
1239 fn project_arithmetic_expressions_with_aliases() {
1240 use arithmetic::{ArithmeticBase, ArithmeticExpression, ArithmeticOperator};
1241
1242 let qstr = "SELECT max(o_id) * 2 as double_max FROM orders;";
1243 let res = selection(qstr.as_bytes());
1244
1245 let expected = SelectStatement {
1246 tables: vec![Table::from("orders")],
1247 fields: vec![FieldDefinitionExpression::Value(
1248 FieldValueExpression::Arithmetic(ArithmeticExpression {
1249 alias: Some(String::from("double_max")),
1250 op: ArithmeticOperator::Multiply,
1251 left: ArithmeticBase::Column(Column {
1252 name: String::from("max(o_id)"),
1253 alias: None,
1254 table: None,
1255 function: Some(Box::new(FunctionExpression::Max(
1256 FunctionArguments::Column("o_id".into()),
1257 ))),
1258 }),
1259 right: ArithmeticBase::Scalar(2.into()),
1260 }),
1261 )],
1262 ..Default::default()
1263 };
1264
1265 assert_eq!(res.unwrap().1, expected);
1266 }
1267
1268 #[test]
1269 fn where_in_clause() {
1270 let qstr = "SELECT `auth_permission`.`content_type_id`, `auth_permission`.`codename`
1271 FROM `auth_permission`
1272 JOIN `django_content_type`
1273 ON ( `auth_permission`.`content_type_id` = `django_content_type`.`id` )
1274 WHERE `auth_permission`.`content_type_id` IN (0);";
1275 let res = selection(qstr.as_bytes());
1276
1277 let expected_where_clause = Some(ComparisonOp(ConditionTree {
1278 left: Box::new(Base(Field(Column::from("auth_permission.content_type_id")))),
1279 right: Box::new(Base(LiteralList(vec![0.into()]))),
1280 operator: Operator::In,
1281 }));
1282
1283 let expected = SelectStatement {
1284 tables: vec![Table::from("auth_permission")],
1285 fields: vec![
1286 FieldDefinitionExpression::Col(Column::from("auth_permission.content_type_id")),
1287 FieldDefinitionExpression::Col(Column::from("auth_permission.codename")),
1288 ],
1289 join: vec![JoinClause {
1290 operator: JoinOperator::Join,
1291 right: JoinRightSide::Table(Table::from("django_content_type")),
1292 constraint: JoinConstraint::On(ComparisonOp(ConditionTree {
1293 operator: Operator::Equal,
1294 left: Box::new(Base(Field(Column::from("auth_permission.content_type_id")))),
1295 right: Box::new(Base(Field(Column::from("django_content_type.id")))),
1296 })),
1297 }],
1298 where_clause: expected_where_clause,
1299 ..Default::default()
1300 };
1301
1302 assert_eq!(res.unwrap().1, expected);
1303 }
1304}