// otter_sql/expr/mod.rs

//! SQL expressions and their evaluation.

3use std::{error::Error, fmt::Display};
4
5use sqlparser::ast;
6
7use crate::{
8    identifier::{ColumnRef, IdentifierError},
9    value::{Value, ValueError},
10    BoundedString,
11};
12
13pub mod eval;
14
15/// An expression
16#[derive(Debug, Clone, PartialEq)]
17pub enum Expr {
18    Value(Value),
19    ColumnRef(ColumnRef),
20    Wildcard,
21    Binary {
22        left: Box<Expr>,
23        op: BinOp,
24        right: Box<Expr>,
25    },
26    Unary {
27        op: UnOp,
28        operand: Box<Expr>,
29    },
30    Function {
31        name: BoundedString,
32        args: Vec<Expr>,
33    },
34}
35
36impl Display for Expr {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        match self {
39            Self::Value(v) => write!(f, "{}", v),
40            Self::ColumnRef(c) => write!(f, "column '{}'", c),
41            Self::Wildcard => write!(f, "*"),
42            Self::Binary { left, op, right } => write!(f, "({} {} {})", left, op, right),
43            Self::Unary { op, operand } => write!(f, "{}{}", op, operand),
44            Self::Function { name, args } => write!(
45                f,
46                "{}({})",
47                name,
48                args.iter()
49                    .map(|a| a.to_string())
50                    .collect::<Vec<String>>()
51                    .join(", ")
52            ),
53        }
54    }
55}
56
/// A binary operator
///
/// `Display` renders the operator as its SQL token, e.g. `<=` or `AND`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum BinOp {
    // Arithmetic.
    Plus,
    Minus,
    Multiply,
    Divide,
    Modulo,
    // Comparison.
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    // Pattern matching.
    Like,
    ILike,
    // Logical connectives.
    And,
    Or,
}

impl Display for BinOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match *self {
            Self::Plus => "+",
            Self::Minus => "-",
            Self::Multiply => "*",
            Self::Divide => "/",
            Self::Modulo => "%",
            Self::Equal => "=",
            Self::NotEqual => "!=",
            Self::LessThan => "<",
            Self::LessThanOrEqual => "<=",
            Self::GreaterThan => ">",
            Self::GreaterThanOrEqual => ">=",
            Self::Like => "LIKE",
            Self::ILike => "ILIKE",
            Self::And => "AND",
            Self::Or => "OR",
        };
        f.write_str(token)
    }
}

/// A unary operator
///
/// `Plus`, `Minus` and `Not` are prefix operators; the `Is*` variants are the
/// SQL postfix tests such as `x IS NULL`. `Display` renders only the operator
/// token itself.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum UnOp {
    // Prefix operators.
    Plus,
    Minus,
    Not,
    // Postfix `IS ...` tests.
    IsFalse,
    IsTrue,
    IsNull,
    IsNotNull,
}

impl Display for UnOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match *self {
            Self::Plus => "+",
            Self::Minus => "-",
            Self::Not => "NOT",
            Self::IsFalse => "IS FALSE",
            Self::IsTrue => "IS TRUE",
            Self::IsNull => "IS NULL",
            Self::IsNotNull => "IS NOT NULL",
        };
        f.write_str(token)
    }
}

133impl TryFrom<ast::Expr> for Expr {
134    type Error = ExprError;
135    fn try_from(expr_ast: ast::Expr) -> Result<Self, Self::Error> {
136        match expr_ast {
137            ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)),
138            ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)),
139            ast::Expr::IsFalse(e) => Ok(Expr::Unary {
140                op: UnOp::IsFalse,
141                operand: Box::new((*e).try_into()?),
142            }),
143            ast::Expr::IsTrue(e) => Ok(Expr::Unary {
144                op: UnOp::IsTrue,
145                operand: Box::new((*e).try_into()?),
146            }),
147            ast::Expr::IsNull(e) => Ok(Expr::Unary {
148                op: UnOp::IsNull,
149                operand: Box::new((*e).try_into()?),
150            }),
151            ast::Expr::IsNotNull(e) => Ok(Expr::Unary {
152                op: UnOp::IsNotNull,
153                operand: Box::new((*e).try_into()?),
154            }),
155            ast::Expr::Between {
156                expr,
157                negated,
158                low,
159                high,
160            } => {
161                let expr: Box<Expr> = Box::new((*expr).try_into()?);
162                let left = Box::new((*low).try_into()?);
163                let right = Box::new((*high).try_into()?);
164                let between = Expr::Binary {
165                    left: Box::new(Expr::Binary {
166                        left,
167                        op: BinOp::LessThanOrEqual,
168                        right: expr.clone(),
169                    }),
170                    op: BinOp::And,
171                    right: Box::new(Expr::Binary {
172                        left: expr,
173                        op: BinOp::LessThanOrEqual,
174                        right,
175                    }),
176                };
177                if negated {
178                    Ok(Expr::Unary {
179                        op: UnOp::Not,
180                        operand: Box::new(between),
181                    })
182                } else {
183                    Ok(between)
184                }
185            }
186            ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary {
187                left: Box::new((*left).try_into()?),
188                op: op.try_into()?,
189                right: Box::new((*right).try_into()?),
190            }),
191            ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary {
192                op: op.try_into()?,
193                operand: Box::new((*expr).try_into()?),
194            }),
195            ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)),
196            ast::Expr::Function(ref f) => Ok(Expr::Function {
197                name: f.name.to_string().as_str().into(),
198                args: f
199                    .args
200                    .iter()
201                    .map(|arg| match arg {
202                        ast::FunctionArg::Unnamed(arg_expr) => match arg_expr {
203                            ast::FunctionArgExpr::Expr(e) => Ok(e.clone().try_into()?),
204                            ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard),
205                            ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr {
206                                reason: "Qualified wildcards are not supported yet",
207                                expr: expr_ast.clone(),
208                            }),
209                        },
210                        ast::FunctionArg::Named { .. } => Err(ExprError::Expr {
211                            reason: "Named function arguments are not supported",
212                            expr: expr_ast.clone(),
213                        }),
214                    })
215                    .collect::<Result<Vec<_>, _>>()?,
216            }),
217            _ => Err(ExprError::Expr {
218                reason: "Unsupported expression",
219                expr: expr_ast,
220            }),
221        }
222    }
223}
224
225impl TryFrom<ast::BinaryOperator> for BinOp {
226    type Error = ExprError;
227    fn try_from(op: ast::BinaryOperator) -> Result<Self, Self::Error> {
228        match op {
229            ast::BinaryOperator::Plus => Ok(BinOp::Plus),
230            ast::BinaryOperator::Minus => Ok(BinOp::Minus),
231            ast::BinaryOperator::Multiply => Ok(BinOp::Multiply),
232            ast::BinaryOperator::Divide => Ok(BinOp::Divide),
233            ast::BinaryOperator::Modulo => Ok(BinOp::Modulo),
234            ast::BinaryOperator::Eq => Ok(BinOp::Equal),
235            ast::BinaryOperator::NotEq => Ok(BinOp::NotEqual),
236            ast::BinaryOperator::Lt => Ok(BinOp::LessThan),
237            ast::BinaryOperator::LtEq => Ok(BinOp::LessThanOrEqual),
238            ast::BinaryOperator::Gt => Ok(BinOp::GreaterThan),
239            ast::BinaryOperator::GtEq => Ok(BinOp::GreaterThanOrEqual),
240            ast::BinaryOperator::Like => Ok(BinOp::Like),
241            ast::BinaryOperator::ILike => Ok(BinOp::ILike),
242            ast::BinaryOperator::And => Ok(BinOp::And),
243            ast::BinaryOperator::Or => Ok(BinOp::Or),
244            // TODO: xor?
245            _ => Err(ExprError::Binary {
246                reason: "Unknown binary operator",
247                op,
248            }),
249        }
250    }
251}
252
253impl TryFrom<ast::UnaryOperator> for UnOp {
254    type Error = ExprError;
255    fn try_from(op: ast::UnaryOperator) -> Result<Self, Self::Error> {
256        match op {
257            ast::UnaryOperator::Plus => Ok(UnOp::Plus),
258            ast::UnaryOperator::Minus => Ok(UnOp::Minus),
259            ast::UnaryOperator::Not => Ok(UnOp::Not),
260            // IsFalse, IsTrue, etc. are handled in TryFrom<ast::Expr> for Expr
261            // since `sqlparser` does not consider them unary operators for some reason.
262            _ => Err(ExprError::Unary {
263                reason: "Unkown unary operator",
264                op,
265            }),
266        }
267    }
268}
269
270/// Error in parsing an expression.
271#[derive(Debug, PartialEq)]
272pub enum ExprError {
273    Expr {
274        reason: &'static str,
275        expr: ast::Expr,
276    },
277    Binary {
278        reason: &'static str,
279        op: ast::BinaryOperator,
280    },
281    Unary {
282        reason: &'static str,
283        op: ast::UnaryOperator,
284    },
285    Value(ValueError),
286    Identifier(IdentifierError),
287}
288
289impl Display for ExprError {
290    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
291        match self {
292            ExprError::Expr { reason, expr } => {
293                write!(f, "ExprError: {}: {}", reason, expr)
294            }
295            ExprError::Binary { reason, op } => {
296                write!(f, "ExprError: {}: {}", reason, op)
297            }
298            ExprError::Unary { reason, op } => {
299                write!(f, "ExprError: {}: {}", reason, op)
300            }
301            ExprError::Value(v) => write!(f, "{}", v),
302            ExprError::Identifier(v) => write!(f, "{}", v),
303        }
304    }
305}
306
307impl From<ValueError> for ExprError {
308    fn from(v: ValueError) -> Self {
309        Self::Value(v)
310    }
311}
312
313impl From<IdentifierError> for ExprError {
314    fn from(i: IdentifierError) -> Self {
315        Self::Identifier(i)
316    }
317}
318
319impl Error for ExprError {}
320
#[cfg(test)]
mod tests {
    use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer};

    use crate::{
        expr::{BinOp, Expr, ExprError, UnOp},
        identifier::ColumnRef,
        value::Value,
    };

    /// Tokenizes and parses `sql` as a single expression using the generic dialect.
    fn parse_expr(sql: &str) -> ast::Expr {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        Parser::new(tokens, &dialect).parse_expr().unwrap()
    }

    /// Parses `sql` and converts the result into the crate's own [`Expr`].
    fn convert(sql: &str) -> Result<Expr, ExprError> {
        parse_expr(sql).try_into()
    }

    /// An unqualified column reference.
    fn col(name: &str) -> Expr {
        Expr::ColumnRef(ColumnRef {
            schema_name: None,
            table_name: None,
            col_name: name.into(),
        })
    }

    /// An integer literal.
    fn int(n: i64) -> Expr {
        Expr::Value(Value::Int64(n))
    }

    fn binary(left: Expr, op: BinOp, right: Expr) -> Expr {
        Expr::Binary {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    fn unary(op: UnOp, operand: Expr) -> Expr {
        Expr::Unary {
            op,
            operand: Box::new(operand),
        }
    }

    #[test]
    fn conversion_from_ast() {
        // Column references, from bare to fully qualified.
        assert_eq!(convert("abc"), Ok(col("abc")));
        assert_ne!(convert("abc"), Ok(col("cab")));
        assert_eq!(
            convert("table1.col1"),
            Ok(Expr::ColumnRef(ColumnRef {
                schema_name: None,
                table_name: Some("table1".into()),
                col_name: "col1".into(),
            }))
        );
        assert_eq!(
            convert("schema1.table1.col1"),
            Ok(Expr::ColumnRef(ColumnRef {
                schema_name: Some("schema1".into()),
                table_name: Some("table1".into()),
                col_name: "col1".into(),
            }))
        );

        // Postfix `IS` tests.
        assert_eq!(convert("5 IS NULL"), Ok(unary(UnOp::IsNull, int(5))));
        assert_eq!(convert("1 IS TRUE"), Ok(unary(UnOp::IsTrue, int(1))));

        // BETWEEN desugars into two comparisons joined by AND.
        let between = binary(
            binary(int(3), BinOp::LessThanOrEqual, int(4)),
            BinOp::And,
            binary(int(4), BinOp::LessThanOrEqual, int(5)),
        );
        assert_eq!(convert("4 BETWEEN 3 AND 5"), Ok(between.clone()));
        assert_eq!(
            convert("4 NOT BETWEEN 3 AND 5"),
            Ok(unary(UnOp::Not, between))
        );

        // Function calls.
        assert_eq!(
            convert("MAX(col1)"),
            Ok(Expr::Function {
                name: "MAX".into(),
                args: vec![col("col1")],
            })
        );
        assert_eq!(
            convert("some_func(col1, 1, 'abc')"),
            Ok(Expr::Function {
                name: "some_func".into(),
                args: vec![
                    col("col1"),
                    int(1),
                    Expr::Value(Value::String("abc".to_owned())),
                ],
            })
        );
        assert_eq!(
            convert("COUNT(*)"),
            Ok(Expr::Function {
                name: "COUNT".into(),
                args: vec![Expr::Wildcard],
            })
        );
    }
}