1use nom::character::complete::{multispace0, multispace1};
2use std::{fmt, str};
3
4use column::Column;
5use common::{
6 as_alias, column_identifier_no_alias, integer_literal, type_identifier, Literal, SqlType,
7};
8use nom::branch::alt;
9use nom::bytes::complete::{tag, tag_no_case};
10use nom::combinator::{map, opt};
11use nom::sequence::{terminated, tuple};
12use nom::IResult;
13
/// A binary arithmetic operator used in an [`ArithmeticExpression`].
///
/// The `Display` impl in this module renders each variant as its SQL symbol.
// NOTE(review): variant names/order feed the derived `Serialize`/`Deserialize`
// impls, so reordering or renaming them may change the serialized form.
#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum ArithmeticOperator {
    /// `+`
    Add,
    /// `-`
    Subtract,
    /// `*`
    Multiply,
    /// `/`
    Divide,
}
21
/// One operand of an [`ArithmeticExpression`]: a column reference or a literal value.
///
/// The parsers in this module only ever build `Scalar` from an integer literal,
/// but the type itself can hold any [`Literal`].
#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum ArithmeticBase {
    /// A (possibly table-qualified or function-valued) column, e.g. `foo` or `MAX(foo)`.
    Column(Column),
    /// A literal constant, e.g. `42`.
    Scalar(Literal),
}
27
/// A single binary arithmetic expression such as `foo + 5` or `2 * 10 AS twenty`.
///
/// Both operands are plain [`ArithmeticBase`] values, so only one operator can be
/// represented (`a + b + c` is not expressible with this type).
#[derive(Debug, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct ArithmeticExpression {
    /// The operator joining `left` and `right`.
    pub op: ArithmeticOperator,
    /// Left-hand operand.
    pub left: ArithmeticBase,
    /// Right-hand operand.
    pub right: ArithmeticBase,
    /// Optional `AS <alias>` name given to the expression.
    pub alias: Option<String>,
}
35
36impl ArithmeticExpression {
37 pub fn new(
38 op: ArithmeticOperator,
39 left: ArithmeticBase,
40 right: ArithmeticBase,
41 alias: Option<String>,
42 ) -> Self {
43 Self {
44 op,
45 left,
46 right,
47 alias,
48 }
49 }
50}
51
52impl fmt::Display for ArithmeticOperator {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 match *self {
55 ArithmeticOperator::Add => write!(f, "+"),
56 ArithmeticOperator::Subtract => write!(f, "-"),
57 ArithmeticOperator::Multiply => write!(f, "*"),
58 ArithmeticOperator::Divide => write!(f, "/"),
59 }
60 }
61}
62
63impl fmt::Display for ArithmeticBase {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 match *self {
66 ArithmeticBase::Column(ref col) => write!(f, "{}", col),
67 ArithmeticBase::Scalar(ref lit) => write!(f, "{}", lit.to_string()),
68 }
69 }
70}
71
72impl fmt::Display for ArithmeticExpression {
73 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74 match self.alias {
75 Some(ref alias) => write!(f, "{} {} {} AS {}", self.left, self.op, self.right, alias),
76 None => write!(f, "{} {} {}", self.left, self.op, self.right),
77 }
78 }
79}
80
81fn arithmetic_cast_helper(i: &[u8]) -> IResult<&[u8], (ArithmeticBase, Option<SqlType>)> {
82 let (remaining_input, (_, _, _, _, a_base, _, _, _, _sign, sql_type, _, _)) = tuple((
83 tag_no_case("cast"),
84 multispace0,
85 tag("("),
86 multispace0,
87 arithmetic_base,
89 multispace1,
90 tag_no_case("as"),
91 multispace1,
92 opt(terminated(tag_no_case("signed"), multispace1)),
93 type_identifier,
94 multispace0,
95 tag(")"),
96 ))(i)?;
97
98 Ok((remaining_input, (a_base, Some(sql_type))))
99}
100
101pub fn arithmetic_cast(i: &[u8]) -> IResult<&[u8], (ArithmeticBase, Option<SqlType>)> {
102 alt((arithmetic_cast_helper, map(arithmetic_base, |v| (v, None))))(i)
103}
104
105pub fn arithmetic_operator(i: &[u8]) -> IResult<&[u8], ArithmeticOperator> {
108 alt((
109 map(tag("+"), |_| ArithmeticOperator::Add),
110 map(tag("-"), |_| ArithmeticOperator::Subtract),
111 map(tag("*"), |_| ArithmeticOperator::Multiply),
112 map(tag("/"), |_| ArithmeticOperator::Divide),
113 ))(i)
114}
115
116pub fn arithmetic_base(i: &[u8]) -> IResult<&[u8], ArithmeticBase> {
118 alt((
119 map(integer_literal, |il| ArithmeticBase::Scalar(il)),
120 map(column_identifier_no_alias, |ci| ArithmeticBase::Column(ci)),
121 ))(i)
122}
123
124pub fn arithmetic_expression(i: &[u8]) -> IResult<&[u8], ArithmeticExpression> {
127 let (remaining_input, (left, _, op, _, right, opt_alias)) = tuple((
128 arithmetic_cast,
129 multispace0,
130 arithmetic_operator,
131 multispace0,
132 arithmetic_cast,
133 opt(as_alias),
134 ))(i)?;
135
136 let alias = match opt_alias {
137 None => None,
138 Some(a) => Some(String::from(a)),
139 };
140
141 Ok((
142 remaining_input,
143 ArithmeticExpression {
144 left: left.0,
145 right: right.0,
146 op,
147 alias,
148 },
149 ))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as an arithmetic expression. On rejection it panics with the offending
    /// input and the parser error, so the failing case is identifiable. Trailing unparsed
    /// input is ignored, as before.
    fn parse_expr(input: &str) -> ArithmeticExpression {
        match arithmetic_expression(input.as_bytes()) {
            Ok((_, expr)) => expr,
            Err(e) => panic!("{} failed to parse: {:?}", input, e),
        }
    }

    /// Asserts that each input parses to the expectation at the same position.
    fn check_all(inputs: &[&str], expected: &[ArithmeticExpression]) {
        // Guard against the two tables drifting apart; `zip` would otherwise silently
        // skip unmatched entries.
        assert_eq!(inputs.len(), expected.len(), "test table length mismatch");
        for (input, want) in inputs.iter().zip(expected) {
            assert_eq!(parse_expr(input), *want, "unexpected parse of {:?}", input);
        }
    }

    #[test]
    fn it_parses_arithmetic_expressions() {
        use super::ArithmeticBase::Column as ABColumn;
        use super::ArithmeticBase::Scalar;
        use super::ArithmeticOperator::*;
        use column::{FunctionArguments, FunctionExpression};

        let lit_ae = [
            "5 + 42",
            "5+42",
            "5 * 42",
            "5 - 42",
            "5 / 42",
            "2 * 10 AS twenty ",
        ];

        let col_lit_ae = [
            "foo+5",
            "foo + 5",
            "5 + foo ",
            "foo * bar AS foobar",
            "MAX(foo)-3333",
        ];

        let expected_lit_ae = [
            ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Add, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Multiply, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(Divide, Scalar(5.into()), Scalar(42.into()), None),
            ArithmeticExpression::new(
                Multiply,
                Scalar(2.into()),
                Scalar(10.into()),
                Some(String::from("twenty")),
            ),
        ];
        let expected_col_lit_ae = [
            ArithmeticExpression::new(Add, ABColumn("foo".into()), Scalar(5.into()), None),
            ArithmeticExpression::new(Add, ABColumn("foo".into()), Scalar(5.into()), None),
            ArithmeticExpression::new(Add, Scalar(5.into()), ABColumn("foo".into()), None),
            ArithmeticExpression::new(
                Multiply,
                ABColumn("foo".into()),
                ABColumn("bar".into()),
                Some(String::from("foobar")),
            ),
            ArithmeticExpression::new(
                Subtract,
                ABColumn(Column {
                    name: String::from("max(foo)"),
                    alias: None,
                    table: None,
                    function: Some(Box::new(FunctionExpression::Max(
                        FunctionArguments::Column("foo".into()),
                    ))),
                }),
                Scalar(3333.into()),
                None,
            ),
        ];

        check_all(&lit_ae, &expected_lit_ae);
        check_all(&col_lit_ae, &expected_col_lit_ae);
    }

    #[test]
    fn it_displays_arithmetic_expressions() {
        use super::ArithmeticBase::Column as ABColumn;
        use super::ArithmeticBase::Scalar;
        use super::ArithmeticOperator::*;

        let expressions = [
            ArithmeticExpression::new(Add, ABColumn("foo".into()), Scalar(5.into()), None),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), ABColumn("foo".into()), None),
            ArithmeticExpression::new(
                Multiply,
                ABColumn("foo".into()),
                ABColumn("bar".into()),
                None,
            ),
            ArithmeticExpression::new(Divide, Scalar(10.into()), Scalar(2.into()), None),
            ArithmeticExpression::new(
                Add,
                Scalar(10.into()),
                Scalar(2.into()),
                Some(String::from("bob")),
            ),
        ];

        let expected_strings = ["foo + 5", "5 - foo", "foo * bar", "10 / 2", "10 + 2 AS bob"];
        assert_eq!(expressions.len(), expected_strings.len(), "test table length mismatch");
        for (expr, want) in expressions.iter().zip(expected_strings.iter()) {
            assert_eq!(*want, format!("{}", expr));
        }
    }

    #[test]
    fn it_parses_arithmetic_casts() {
        use super::ArithmeticBase::Column as ABColumn;
        use super::ArithmeticBase::Scalar;
        use super::ArithmeticOperator::*;

        let exprs = [
            "CAST(`t`.`foo` AS signed int) + CAST(`t`.`bar` AS signed int) ",
            "CAST(5 AS bigint) - foo ",
            "CAST(5 AS bigint) - foo AS 5_minus_foo",
        ];

        let expected = [
            ArithmeticExpression::new(
                Add,
                ABColumn(Column::from("t.foo")),
                ABColumn(Column::from("t.bar")),
                None,
            ),
            ArithmeticExpression::new(Subtract, Scalar(5.into()), ABColumn("foo".into()), None),
            ArithmeticExpression::new(
                Subtract,
                Scalar(5.into()),
                ABColumn("foo".into()),
                Some("5_minus_foo".into()),
            ),
        ];

        check_all(&exprs, &expected);
    }
}