1use nom::character::complete::{multispace0, multispace1};
2use std::fmt;
3use std::str;
4
5use common::statement_terminator;
6use nom::branch::alt;
7use nom::bytes::complete::{tag, tag_no_case};
8use nom::combinator::{map, opt};
9use nom::multi::many1;
10use nom::sequence::{delimited, preceded, tuple};
11use nom::IResult;
12use order::{order_clause, OrderClause};
13use select::{limit_clause, nested_selection, LimitClause, SelectStatement};
14
/// Set operator that joins one `SELECT` to the preceding one inside a
/// [`CompoundSelectStatement`].
///
/// NOTE(review): variant order may be significant to serde formats that encode
/// the variant index — avoid reordering.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub enum CompoundSelectOperator {
    /// `UNION ALL`. The parser in this module produces this variant only when
    /// the `ALL` quantifier is written explicitly.
    Union,
    /// `UNION DISTINCT`, or a bare `UNION` with no quantifier.
    DistinctUnion,
    /// `INTERSECT`.
    Intersect,
    /// `EXCEPT`.
    Except,
}
22
23impl fmt::Display for CompoundSelectOperator {
24 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25 match *self {
26 CompoundSelectOperator::Union => write!(f, "UNION"),
27 CompoundSelectOperator::DistinctUnion => write!(f, "UNION DISTINCT"),
28 CompoundSelectOperator::Intersect => write!(f, "INTERSECT"),
29 CompoundSelectOperator::Except => write!(f, "EXCEPT"),
30 }
31 }
32}
33
/// A sequence of `SELECT`s joined by set operators (e.g.
/// `SELECT ... UNION SELECT ...`), plus the optional `ORDER BY` and `LIMIT`
/// clauses that follow the last `SELECT`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub struct CompoundSelectStatement {
    /// Constituent selects in source order, each paired with the operator
    /// that precedes it. As built by `compound_selection`, the first entry's
    /// operator is `None` and every later entry's is `Some`.
    pub selects: Vec<(Option<CompoundSelectOperator>, SelectStatement)>,
    /// `ORDER BY` clause parsed after the final select, if present.
    pub order: Option<OrderClause>,
    /// `LIMIT` clause parsed after the final select (and after any
    /// `ORDER BY`), if present.
    pub limit: Option<LimitClause>,
}
40
41impl fmt::Display for CompoundSelectStatement {
42 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43 for (ref op, ref sel) in &self.selects {
44 if op.is_some() {
45 write!(f, " {}", op.as_ref().unwrap())?;
46 }
47 write!(f, " {}", sel)?;
48 }
49 if self.order.is_some() {
50 write!(f, " {}", self.order.as_ref().unwrap())?;
51 }
52 if self.limit.is_some() {
53 write!(f, " {}", self.order.as_ref().unwrap())?;
54 }
55 Ok(())
56 }
57}
58
59fn compound_op(i: &[u8]) -> IResult<&[u8], CompoundSelectOperator> {
61 alt((
62 map(
63 preceded(
64 tag_no_case("union"),
65 opt(preceded(
66 multispace1,
67 alt((
68 map(tag_no_case("all"), |_| false),
69 map(tag_no_case("distinct"), |_| true),
70 )),
71 )),
72 ),
73 |distinct| match distinct {
74 None => CompoundSelectOperator::DistinctUnion,
76 Some(d) => {
77 if d {
78 CompoundSelectOperator::DistinctUnion
79 } else {
80 CompoundSelectOperator::Union
81 }
82 }
83 },
84 ),
85 map(tag_no_case("intersect"), |_| {
86 CompoundSelectOperator::Intersect
87 }),
88 map(tag_no_case("except"), |_| CompoundSelectOperator::Except),
89 ))(i)
90}
91
92fn other_selects(i: &[u8]) -> IResult<&[u8], (Option<CompoundSelectOperator>, SelectStatement)> {
93 let (remaining_input, (_, op, _, _, _, select, _, _)) = tuple((
94 multispace0,
95 compound_op,
96 multispace1,
97 opt(tag("(")),
98 multispace0,
99 nested_selection,
100 multispace0,
101 opt(tag(")")),
102 ))(i)?;
103
104 Ok((remaining_input, (Some(op), select)))
105}
106
107pub fn compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> {
109 let (remaining_input, (first_select, other_selects, _, order, limit, _)) = tuple((
110 delimited(opt(tag("(")), nested_selection, opt(tag(")"))),
111 many1(other_selects),
112 multispace0,
113 opt(order_clause),
114 opt(limit_clause),
115 statement_terminator,
116 ))(i)?;
117
118 let mut selects = vec![(None, first_select)];
119 selects.extend(other_selects);
120
121 Ok((
122 remaining_input,
123 CompoundSelectStatement {
124 selects,
125 order,
126 limit,
127 },
128 ))
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use column::Column;
    use common::{FieldDefinitionExpression, FieldValueExpression, Literal};
    use table::Table;

    /// Parses `query` as a compound selection, panicking if it fails, and
    /// returns just the parsed statement.
    fn parse(query: &str) -> CompoundSelectStatement {
        compound_selection(query.as_bytes()).unwrap().1
    }

    /// Expected tree for `SELECT id, 1 FROM Vote`.
    fn vote_select() -> SelectStatement {
        SelectStatement {
            tables: vec![Table::from("Vote")],
            fields: vec![
                FieldDefinitionExpression::Col(Column::from("id")),
                FieldDefinitionExpression::Value(FieldValueExpression::Literal(
                    Literal::Integer(1).into(),
                )),
            ],
            ..Default::default()
        }
    }

    /// Expected tree for `SELECT id, stars FROM Rating`.
    fn rating_select() -> SelectStatement {
        SelectStatement {
            tables: vec![Table::from("Rating")],
            fields: vec![
                FieldDefinitionExpression::Col(Column::from("id")),
                FieldDefinitionExpression::Col(Column::from("stars")),
            ],
            ..Default::default()
        }
    }

    /// Wraps `selects` in a compound statement with no ORDER BY or LIMIT.
    fn without_order_or_limit(
        selects: Vec<(Option<CompoundSelectOperator>, SelectStatement)>,
    ) -> CompoundSelectStatement {
        CompoundSelectStatement {
            selects,
            order: None,
            limit: None,
        }
    }

    #[test]
    fn union() {
        let expected = without_order_or_limit(vec![
            (None, vote_select()),
            (Some(CompoundSelectOperator::DistinctUnion), rating_select()),
        ]);

        // Bare and parenthesised operands must parse to the same tree.
        assert_eq!(
            parse("SELECT id, 1 FROM Vote UNION SELECT id, stars from Rating;"),
            expected
        );
        assert_eq!(
            parse("(SELECT id, 1 FROM Vote) UNION (SELECT id, stars from Rating);"),
            expected
        );
    }

    #[test]
    fn multi_union() {
        let query = "SELECT id, 1 FROM Vote \
                     UNION SELECT id, stars from Rating \
                     UNION DISTINCT SELECT 42, 5 FROM Vote;";

        let constants_select = SelectStatement {
            tables: vec![Table::from("Vote")],
            fields: vec![
                FieldDefinitionExpression::Value(FieldValueExpression::Literal(
                    Literal::Integer(42).into(),
                )),
                FieldDefinitionExpression::Value(FieldValueExpression::Literal(
                    Literal::Integer(5).into(),
                )),
            ],
            ..Default::default()
        };

        // A bare UNION and an explicit UNION DISTINCT map to the same operator.
        let expected = without_order_or_limit(vec![
            (None, vote_select()),
            (Some(CompoundSelectOperator::DistinctUnion), rating_select()),
            (Some(CompoundSelectOperator::DistinctUnion), constants_select),
        ]);

        assert_eq!(parse(query), expected);
    }

    #[test]
    fn union_all() {
        let expected = without_order_or_limit(vec![
            (None, vote_select()),
            (Some(CompoundSelectOperator::Union), rating_select()),
        ]);

        assert_eq!(
            parse("SELECT id, 1 FROM Vote UNION ALL SELECT id, stars from Rating;"),
            expected
        );
    }
}
261}