nom_sql/
compound_select.rs

1use nom::character::complete::{multispace0, multispace1};
2use std::fmt;
3use std::str;
4
5use common::statement_terminator;
6use nom::branch::alt;
7use nom::bytes::complete::{tag, tag_no_case};
8use nom::combinator::{map, opt};
9use nom::multi::many1;
10use nom::sequence::{delimited, preceded, tuple};
11use nom::IResult;
12use order::{order_clause, OrderClause};
13use select::{limit_clause, nested_selection, LimitClause, SelectStatement};
14
15#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
16pub enum CompoundSelectOperator {
17    Union,
18    DistinctUnion,
19    Intersect,
20    Except,
21}
22
23impl fmt::Display for CompoundSelectOperator {
24    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25        match *self {
26            CompoundSelectOperator::Union => write!(f, "UNION"),
27            CompoundSelectOperator::DistinctUnion => write!(f, "UNION DISTINCT"),
28            CompoundSelectOperator::Intersect => write!(f, "INTERSECT"),
29            CompoundSelectOperator::Except => write!(f, "EXCEPT"),
30        }
31    }
32}
33
34#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
35pub struct CompoundSelectStatement {
36    pub selects: Vec<(Option<CompoundSelectOperator>, SelectStatement)>,
37    pub order: Option<OrderClause>,
38    pub limit: Option<LimitClause>,
39}
40
41impl fmt::Display for CompoundSelectStatement {
42    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
43        for (ref op, ref sel) in &self.selects {
44            if op.is_some() {
45                write!(f, " {}", op.as_ref().unwrap())?;
46            }
47            write!(f, " {}", sel)?;
48        }
49        if self.order.is_some() {
50            write!(f, " {}", self.order.as_ref().unwrap())?;
51        }
52        if self.limit.is_some() {
53            write!(f, " {}", self.order.as_ref().unwrap())?;
54        }
55        Ok(())
56    }
57}
58
59// Parse compound operator
60fn compound_op(i: &[u8]) -> IResult<&[u8], CompoundSelectOperator> {
61    alt((
62        map(
63            preceded(
64                tag_no_case("union"),
65                opt(preceded(
66                    multispace1,
67                    alt((
68                        map(tag_no_case("all"), |_| false),
69                        map(tag_no_case("distinct"), |_| true),
70                    )),
71                )),
72            ),
73            |distinct| match distinct {
74                // DISTINCT is the default in both MySQL and SQLite
75                None => CompoundSelectOperator::DistinctUnion,
76                Some(d) => {
77                    if d {
78                        CompoundSelectOperator::DistinctUnion
79                    } else {
80                        CompoundSelectOperator::Union
81                    }
82                }
83            },
84        ),
85        map(tag_no_case("intersect"), |_| {
86            CompoundSelectOperator::Intersect
87        }),
88        map(tag_no_case("except"), |_| CompoundSelectOperator::Except),
89    ))(i)
90}
91
92fn other_selects(i: &[u8]) -> IResult<&[u8], (Option<CompoundSelectOperator>, SelectStatement)> {
93    let (remaining_input, (_, op, _, _, _, select, _, _)) = tuple((
94        multispace0,
95        compound_op,
96        multispace1,
97        opt(tag("(")),
98        multispace0,
99        nested_selection,
100        multispace0,
101        opt(tag(")")),
102    ))(i)?;
103
104    Ok((remaining_input, (Some(op), select)))
105}
106
107// Parse compound selection
108pub fn compound_selection(i: &[u8]) -> IResult<&[u8], CompoundSelectStatement> {
109    let (remaining_input, (first_select, other_selects, _, order, limit, _)) = tuple((
110        delimited(opt(tag("(")), nested_selection, opt(tag(")"))),
111        many1(other_selects),
112        multispace0,
113        opt(order_clause),
114        opt(limit_clause),
115        statement_terminator,
116    ))(i)?;
117
118    let mut selects = vec![(None, first_select)];
119    selects.extend(other_selects);
120
121    Ok((
122        remaining_input,
123        CompoundSelectStatement {
124            selects,
125            order,
126            limit,
127        },
128    ))
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use column::Column;
    use common::{FieldDefinitionExpression, FieldValueExpression, Literal};
    use table::Table;

    /// Expected AST for `SELECT id, 1 FROM Vote`.
    fn vote_select() -> SelectStatement {
        SelectStatement {
            tables: vec![Table::from("Vote")],
            fields: vec![
                FieldDefinitionExpression::Col(Column::from("id")),
                FieldDefinitionExpression::Value(FieldValueExpression::Literal(
                    Literal::Integer(1).into(),
                )),
            ],
            ..Default::default()
        }
    }

    /// Expected AST for `SELECT id, stars FROM Rating`.
    fn rating_select() -> SelectStatement {
        SelectStatement {
            tables: vec![Table::from("Rating")],
            fields: vec![
                FieldDefinitionExpression::Col(Column::from("id")),
                FieldDefinitionExpression::Col(Column::from("stars")),
            ],
            ..Default::default()
        }
    }

    /// Builds an expected compound statement. `first` carries no operator,
    /// each entry of `rest` carries its operator, and there is no ORDER BY or
    /// LIMIT.
    fn compound(
        first: SelectStatement,
        rest: Vec<(CompoundSelectOperator, SelectStatement)>,
    ) -> CompoundSelectStatement {
        let mut selects = vec![(None, first)];
        selects.extend(rest.into_iter().map(|(op, sel)| (Some(op), sel)));
        CompoundSelectStatement {
            selects,
            order: None,
            limit: None,
        }
    }

    #[test]
    fn union() {
        let plain = compound_selection(
            "SELECT id, 1 FROM Vote UNION SELECT id, stars from Rating;".as_bytes(),
        );
        let parenthesized = compound_selection(
            "(SELECT id, 1 FROM Vote) UNION (SELECT id, stars from Rating);".as_bytes(),
        );
        let expected = compound(
            vote_select(),
            vec![(CompoundSelectOperator::DistinctUnion, rating_select())],
        );

        assert_eq!(plain.unwrap().1, expected);
        assert_eq!(parenthesized.unwrap().1, expected);
    }

    #[test]
    fn multi_union() {
        let qstr = "SELECT id, 1 FROM Vote \
                    UNION SELECT id, stars from Rating \
                    UNION DISTINCT SELECT 42, 5 FROM Vote;";
        let res = compound_selection(qstr.as_bytes());

        // `SELECT 42, 5 FROM Vote`
        let constants_select = SelectStatement {
            tables: vec![Table::from("Vote")],
            fields: vec![
                FieldDefinitionExpression::Value(FieldValueExpression::Literal(
                    Literal::Integer(42).into(),
                )),
                FieldDefinitionExpression::Value(FieldValueExpression::Literal(
                    Literal::Integer(5).into(),
                )),
            ],
            ..Default::default()
        };
        let expected = compound(
            vote_select(),
            vec![
                (CompoundSelectOperator::DistinctUnion, rating_select()),
                (CompoundSelectOperator::DistinctUnion, constants_select),
            ],
        );

        assert_eq!(res.unwrap().1, expected);
    }

    #[test]
    fn union_all() {
        let res = compound_selection(
            "SELECT id, 1 FROM Vote UNION ALL SELECT id, stars from Rating;".as_bytes(),
        );
        let expected = compound(
            vote_select(),
            vec![(CompoundSelectOperator::Union, rating_select())],
        );

        assert_eq!(res.unwrap().1, expected);
    }
}