nom_sql/
arithmetic.rs

1use nom::character::complete::{multispace0, multispace1};
2use std::{fmt, str};
3
4use column::Column;
5use common::{
6    as_alias, column_identifier_no_alias, integer_literal, type_identifier, Literal, SqlType,
7};
8use nom::branch::alt;
9use nom::bytes::complete::{tag, tag_no_case};
10use nom::combinator::{map, opt};
11use nom::sequence::{terminated, tuple};
12use nom::IResult;
13
14#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
15pub enum ArithmeticOperator {
16    Add,
17    Subtract,
18    Multiply,
19    Divide,
20}
21
22#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
23pub enum ArithmeticBase {
24    Column(Column),
25    Scalar(Literal),
26}
27
28#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
29pub struct ArithmeticExpression {
30    pub op: ArithmeticOperator,
31    pub left: ArithmeticBase,
32    pub right: ArithmeticBase,
33    pub alias: Option<String>,
34}
35
36impl ArithmeticExpression {
37    pub fn new(
38        op: ArithmeticOperator,
39        left: ArithmeticBase,
40        right: ArithmeticBase,
41        alias: Option<String>,
42    ) -> Self {
43        Self {
44            op,
45            left,
46            right,
47            alias,
48        }
49    }
50}
51
52impl fmt::Display for ArithmeticOperator {
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        match *self {
55            ArithmeticOperator::Add => write!(f, "+"),
56            ArithmeticOperator::Subtract => write!(f, "-"),
57            ArithmeticOperator::Multiply => write!(f, "*"),
58            ArithmeticOperator::Divide => write!(f, "/"),
59        }
60    }
61}
62
63impl fmt::Display for ArithmeticBase {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        match *self {
66            ArithmeticBase::Column(ref col) => write!(f, "{}", col),
67            ArithmeticBase::Scalar(ref lit) => write!(f, "{}", lit.to_string()),
68        }
69    }
70}
71
72impl fmt::Display for ArithmeticExpression {
73    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74        match self.alias {
75            Some(ref alias) => write!(f, "{} {} {} AS {}", self.left, self.op, self.right, alias),
76            None => write!(f, "{} {} {}", self.left, self.op, self.right),
77        }
78    }
79}
80
81fn arithmetic_cast_helper(i: &[u8]) -> IResult<&[u8], (ArithmeticBase, Option<SqlType>)> {
82    let (remaining_input, (_, _, _, _, a_base, _, _, _, _sign, sql_type, _, _)) = tuple((
83        tag_no_case("cast"),
84        multispace0,
85        tag("("),
86        multispace0,
87        // TODO(malte): should be arbitrary expr
88        arithmetic_base,
89        multispace1,
90        tag_no_case("as"),
91        multispace1,
92        opt(terminated(tag_no_case("signed"), multispace1)),
93        type_identifier,
94        multispace0,
95        tag(")"),
96    ))(i)?;
97
98    Ok((remaining_input, (a_base, Some(sql_type))))
99}
100
101pub fn arithmetic_cast(i: &[u8]) -> IResult<&[u8], (ArithmeticBase, Option<SqlType>)> {
102    alt((arithmetic_cast_helper, map(arithmetic_base, |v| (v, None))))(i)
103}
104
105// Parse standard math operators.
106// TODO(malte): this doesn't currently observe operator precedence.
107pub fn arithmetic_operator(i: &[u8]) -> IResult<&[u8], ArithmeticOperator> {
108    alt((
109        map(tag("+"), |_| ArithmeticOperator::Add),
110        map(tag("-"), |_| ArithmeticOperator::Subtract),
111        map(tag("*"), |_| ArithmeticOperator::Multiply),
112        map(tag("/"), |_| ArithmeticOperator::Divide),
113    ))(i)
114}
115
116// Base case for nested arithmetic expressions: column name or literal.
117pub fn arithmetic_base(i: &[u8]) -> IResult<&[u8], ArithmeticBase> {
118    alt((
119        map(integer_literal, |il| ArithmeticBase::Scalar(il)),
120        map(column_identifier_no_alias, |ci| ArithmeticBase::Column(ci)),
121    ))(i)
122}
123
124// Parse simple arithmetic expressions combining literals, and columns and literals.
125// TODO(malte): this doesn't currently support nested expressions.
126pub fn arithmetic_expression(i: &[u8]) -> IResult<&[u8], ArithmeticExpression> {
127    let (remaining_input, (left, _, op, _, right, opt_alias)) = tuple((
128        arithmetic_cast,
129        multispace0,
130        arithmetic_operator,
131        multispace0,
132        arithmetic_cast,
133        opt(as_alias),
134    ))(i)?;
135
136    let alias = match opt_alias {
137        None => None,
138        Some(a) => Some(String::from(a)),
139    };
140
141    Ok((
142        remaining_input,
143        ArithmeticExpression {
144            left: left.0,
145            right: right.0,
146            op,
147            alias,
148        },
149    ))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses each input with `arithmetic_expression` and compares it with the expectation
    /// at the same position, naming the offending input on failure.
    fn check_parses(inputs: &[&str], expected: &[ArithmeticExpression]) {
        // `zip` stops at the shorter side, so guard against a forgotten expectation.
        assert_eq!(
            inputs.len(),
            expected.len(),
            "each input needs exactly one expected expression"
        );
        for (input, want) in inputs.iter().zip(expected) {
            let res = arithmetic_expression(input.as_bytes());
            assert!(res.is_ok(), "{} failed to parse", input);
            assert_eq!(res.unwrap().1, *want, "unexpected parse of {:?}", input);
        }
    }

    #[test]
    fn it_parses_arithmetic_expressions() {
        use super::ArithmeticBase::Column as ABColumn;
        use super::ArithmeticBase::Scalar;
        use super::ArithmeticOperator::*;
        use column::{FunctionArguments, FunctionExpression};

        let lit_ae = [
            "5 + 42",
            "5+42",
            "5 * 42",
            "5 - 42",
            "5 / 42",
            "2 * 10 AS twenty ",
        ];

        // N.B. trailing space in "5 + foo " is required because `sql_identifier`'s keyword
        // detection requires a follow-up character (in practice, there always is one because we
        // use semicolon-terminated queries).
        let col_lit_ae = [
            "foo+5",
            "foo + 5",
            "5 + foo ",
            "foo * bar AS foobar",
            "MAX(foo)-3333",
        ];

        let expected_lit_ae = [
            ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Multiply, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Divide, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(
                Multiply,
                Scalar(2.into()),
                Scalar(10.into()),
                Some(String::from("twenty")),
            ),
        ];
        let expected_col_lit_ae = [
            ArithmeticExpression::new(Add, ABColumn("foo".into()), Scalar(5.into()), None),
            ArithmeticExpression::new(Add, ABColumn("foo".into()), Scalar(5.into()), None),
            ArithmeticExpression::new(Add, Scalar(5.into()), ABColumn("foo".into()), None),
            ArithmeticExpression::new(
                Multiply,
                ABColumn("foo".into()),
                ABColumn("bar".into()),
                Some(String::from("foobar")),
            ),
            ArithmeticExpression::new(
                Subtract,
                ABColumn(Column {
                    name: String::from("max(foo)"),
                    alias: None,
                    table: None,
                    function: Some(Box::new(FunctionExpression::Max(
                        FunctionArguments::Column("foo".into()),
                    ))),
                }),
                Scalar(3333.into()),
                None,
            ),
        ];

        check_parses(&lit_ae, &expected_lit_ae);
        check_parses(&col_lit_ae, &expected_col_lit_ae);
    }

    #[test]
    fn it_displays_arithmetic_expressions() {
        use super::ArithmeticBase::Column as ABColumn;
        use super::ArithmeticBase::Scalar;
        use super::ArithmeticOperator::*;

        let expressions = [
            ArithmeticExpression::new(Add, ABColumn("foo".into()), Scalar(5.into()), None),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), ABColumn("foo".into()), None),
            ArithmeticExpression::new(
                Multiply,
                ABColumn("foo".into()),
                ABColumn("bar".into()),
                None,
            ),
            ArithmeticExpression::new(Divide, Scalar(10.into()), Scalar(2.into()), None),
            ArithmeticExpression::new(
                Add,
                Scalar(10.into()),
                Scalar(2.into()),
                Some(String::from("bob")),
            ),
        ];

        let expected_strings = ["foo + 5", "5 - foo", "foo * bar", "10 / 2", "10 + 2 AS bob"];
        assert_eq!(
            expressions.len(),
            expected_strings.len(),
            "each expression needs exactly one expected rendering"
        );
        for (expr, want) in expressions.iter().zip(expected_strings.iter()) {
            assert_eq!(*want, format!("{}", expr));
        }
    }

    #[test]
    fn it_parses_arithmetic_casts() {
        use super::ArithmeticBase::Column as ABColumn;
        use super::ArithmeticBase::Scalar;
        use super::ArithmeticOperator::*;

        let exprs = [
            "CAST(`t`.`foo` AS signed int) + CAST(`t`.`bar` AS signed int) ",
            "CAST(5 AS bigint) - foo ",
            "CAST(5 AS bigint) - foo AS 5_minus_foo",
        ];

        // XXX(malte): currently discards the cast and type information!
        let expected = [
            ArithmeticExpression::new(
                Add,
                ABColumn(Column::from("t.foo")),
                ABColumn(Column::from("t.bar")),
                None,
            ),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), ABColumn("foo".into()), None),
            ArithmeticExpression::new(
                Subtract,
                Scalar(5.into()),
                ABColumn("foo".into()),
                Some("5_minus_foo".into()),
            ),
        ];

        check_parses(&exprs, &expected);
    }
}
299}