otter_sql/expr/
eval.rs

1//! Evaluator of expressions.
2
3use std::{error::Error, fmt::Display};
4
5use crate::{
6    expr::{BinOp, Expr, UnOp},
7    identifier::BoundedString,
8    table::{RowLike, RowShared, Table},
9    value::{Value, ValueBinaryOpError, ValueUnaryOpError},
10};
11
12impl Expr {
13    pub fn execute(expr: &Expr, table: &Table, row: RowShared) -> Result<Value, ExprExecError> {
14        match expr {
15            Expr::Value(v) => Ok(v.to_owned()),
16            Expr::Binary {
17                left,
18                op: BinOp::And,
19                right,
20            } => {
21                let left = Expr::execute(left, table, row.clone())?;
22                let right = Expr::execute(right, table, row)?;
23
24                match (&left, &right) {
25                    (Value::Bool(left), Value::Bool(right)) => Ok(Value::Bool(*left && *right)),
26                    _ => Err(ExprExecError::ValueBinaryOpError(ValueBinaryOpError {
27                        operator: BinOp::And,
28                        values: (left, right),
29                    })),
30                }
31            }
32            Expr::Binary {
33                left,
34                op: BinOp::Or,
35                right,
36            } => {
37                let left = Expr::execute(left, table, row.clone())?;
38                let right = Expr::execute(right, table, row)?;
39
40                match (&left, &right) {
41                    (Value::Bool(left), Value::Bool(right)) => Ok(Value::Bool(*left || *right)),
42                    _ => Err(ExprExecError::ValueBinaryOpError(ValueBinaryOpError {
43                        operator: BinOp::Or,
44                        values: (left, right),
45                    })),
46                }
47            }
48            Expr::Binary { left, op, right } => {
49                let left = Expr::execute(left, table, row.clone())?;
50                let right = Expr::execute(right, table, row)?;
51                Ok(match op {
52                    BinOp::Plus => left + right,
53                    BinOp::Minus => left - right,
54                    BinOp::Multiply => left * right,
55                    BinOp::Divide => left / right,
56                    BinOp::Modulo => left % right,
57                    BinOp::Equal => Ok(Value::Bool(left == right)),
58                    BinOp::NotEqual => Ok(Value::Bool(left != right)),
59                    BinOp::LessThan => Ok(Value::Bool(left < right)),
60                    BinOp::LessThanOrEqual => Ok(Value::Bool(left <= right)),
61                    BinOp::GreaterThan => Ok(Value::Bool(left > right)),
62                    BinOp::GreaterThanOrEqual => Ok(Value::Bool(left >= right)),
63                    BinOp::Like => left.like(right),
64                    BinOp::ILike => left.ilike(right),
65                    BinOp::And | BinOp::Or => {
66                        unreachable!("AND and OR should be handled separately")
67                    }
68                }?)
69            }
70            Expr::Unary { op, operand } => {
71                let operand = Expr::execute(operand, table, row)?;
72                Ok(match op {
73                    UnOp::Plus => Ok(operand),
74                    UnOp::Minus => -operand,
75                    UnOp::Not => !operand,
76                    UnOp::IsFalse => operand.is_false(),
77                    UnOp::IsTrue => operand.is_true(),
78                    UnOp::IsNull => operand.is_null(),
79                    UnOp::IsNotNull => operand.is_not_null(),
80                }?)
81            }
82            Expr::Wildcard => Err(ExprExecError::CannotExecute(expr.to_owned())),
83            Expr::ColumnRef(col_ref) => {
84                let col_index = if let Some(col_index) =
85                    table.columns().position(|c| c.name() == &col_ref.col_name)
86                {
87                    col_index
88                } else {
89                    // TODO: show table name here too
90                    // and think of how it will work for JOINs and temp tables
91                    return Err(ExprExecError::NoSuchColumn(col_ref.col_name));
92                };
93                if let Some(val) = row.data().get(col_index) {
94                    Ok(val.clone())
95                } else {
96                    // TODO: show the row here too
97                    return Err(ExprExecError::CorruptedData {
98                        col_name: col_ref.col_name,
99                        table_name: *table.name(),
100                    });
101                }
102            }
103            // TODO: functions
104            Expr::Function { name: _, args: _ } => todo!(),
105        }
106    }
107}
108
109/// Error in execution of an expression.
110#[derive(Debug, PartialEq)]
111pub enum ExprExecError {
112    CannotExecute(Expr),
113    ValueBinaryOpError(ValueBinaryOpError),
114    ValueUnaryOpError(ValueUnaryOpError),
115    NoSuchColumn(BoundedString),
116    CorruptedData {
117        col_name: BoundedString,
118        table_name: BoundedString,
119    },
120}
121
122impl From<ValueBinaryOpError> for ExprExecError {
123    fn from(e: ValueBinaryOpError) -> Self {
124        Self::ValueBinaryOpError(e)
125    }
126}
127
128impl From<ValueUnaryOpError> for ExprExecError {
129    fn from(e: ValueUnaryOpError) -> Self {
130        Self::ValueUnaryOpError(e)
131    }
132}
133
134impl Display for ExprExecError {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        match self {
137            Self::CannotExecute(expr) => write!(f, "ExprExecError: cannot execute '{}'", expr),
138            Self::ValueBinaryOpError(e) => write!(f, "ExprExecError: {}", e),
139            Self::ValueUnaryOpError(e) => write!(f, "ExprExecError: {}", e),
140            Self::NoSuchColumn(col_name) => {
141                write!(f, "ExprExecError: no such column '{}'", col_name)
142            }
143            Self::CorruptedData {
144                col_name,
145                table_name,
146            } => write!(
147                f,
148                "ExprExecError: data is corrupted for column '{}' of table '{}'",
149                col_name, table_name
150            ),
151        }
152    }
153}
154
155impl Error for ExprExecError {}
156
/// Tests for expression evaluation, driven by SQL snippets parsed with `sqlparser`.
#[cfg(test)]
mod test {
    use sqlparser::{
        ast::{ColumnOption, ColumnOptionDef, DataType},
        dialect::GenericDialect,
        parser::Parser,
        tokenizer::Tokenizer,
    };

    use crate::{
        column::Column,
        expr::{BinOp, Expr, UnOp},
        table::{Row, Table},
        value::{Value, ValueBinaryOpError, ValueUnaryOpError},
    };

    use super::ExprExecError;

    /// Parses a SQL expression string into this crate's [`Expr`], panicking on any failure.
    fn str_to_expr(s: &str) -> Expr {
        let dialect = GenericDialect {};
        let mut tokenizer = Tokenizer::new(&dialect, s);
        let tokens = tokenizer.tokenize().unwrap();
        let mut parser = Parser::new(tokens, &dialect);
        let parsed = parser.parse_expr().unwrap();
        parsed.try_into().unwrap()
    }

    /// Executes `expr` against a temporary table that holds a single empty row.
    fn exec_expr_no_context(expr: Expr) -> Result<Value, ExprExecError> {
        let mut table = Table::new_temp(0);
        table.new_row(vec![]);
        let row = table.all_data()[0].to_shared();
        Expr::execute(&expr, &table, row)
    }

    /// Parses `s` and executes it without any meaningful table context.
    fn exec_str_no_context(s: &str) -> Result<Value, ExprExecError> {
        exec_expr_no_context(str_to_expr(s))
    }

    /// Parses `s` and executes it against the given `table` and `row`.
    fn exec_str_with_context(s: &str, table: &Table, row: &Row) -> Result<Value, ExprExecError> {
        let expr = str_to_expr(s);
        Expr::execute(&expr, table, row.to_shared())
    }

    /// Asserts that `sql`, executed without context, evaluates to `expected`.
    #[track_caller]
    fn assert_evaluates_to(sql: &str, expected: Value) {
        assert_eq!(exec_str_no_context(sql), Ok(expected));
    }

    /// Asserts that `sql`, executed without context, fails with `expected`.
    #[track_caller]
    fn assert_fails_with(sql: &str, expected: impl Into<ExprExecError>) {
        assert_eq!(exec_str_no_context(sql), Err(expected.into()));
    }

    #[test]
    fn exec_value() {
        assert_evaluates_to("NULL", Value::Null);
        assert_evaluates_to("true", Value::Bool(true));
        assert_evaluates_to("1", Value::Int64(1));
        assert_evaluates_to("1.1", Value::Float64(1.1.into()));
        assert_evaluates_to(".1", Value::Float64(0.1.into()));
        assert_evaluates_to("'str'", Value::String("str".to_owned()));
    }

    #[test]
    fn exec_logical() {
        assert_evaluates_to("true and true", Value::Bool(true));
        assert_evaluates_to("true and false", Value::Bool(false));
        assert_evaluates_to("false and true", Value::Bool(false));
        assert_evaluates_to("false and false", Value::Bool(false));
        // Both operands are always evaluated, so a non-boolean on either side is an error.
        assert_fails_with(
            "false and 10",
            ValueBinaryOpError {
                operator: BinOp::And,
                values: (Value::Bool(false), Value::Int64(10)),
            },
        );
        assert_fails_with(
            "10 and false",
            ValueBinaryOpError {
                operator: BinOp::And,
                values: (Value::Int64(10), Value::Bool(false)),
            },
        );

        assert_evaluates_to("true or true", Value::Bool(true));
        assert_evaluates_to("true or false", Value::Bool(true));
        assert_evaluates_to("false or true", Value::Bool(true));
        assert_evaluates_to("false or false", Value::Bool(false));
        assert_fails_with(
            "true or 10",
            ValueBinaryOpError {
                operator: BinOp::Or,
                values: (Value::Bool(true), Value::Int64(10)),
            },
        );
        assert_fails_with(
            "10 or true",
            ValueBinaryOpError {
                operator: BinOp::Or,
                values: (Value::Int64(10), Value::Bool(true)),
            },
        );
    }

    #[test]
    fn exec_arithmetic() {
        assert_evaluates_to("1 + 1", Value::Int64(2));
        assert_evaluates_to("1.1 + 1.1", Value::Float64(2.2.into()));

        // this applies to all binary ops
        assert_fails_with(
            "1 + 1.1",
            ValueBinaryOpError {
                operator: BinOp::Plus,
                values: (Value::Int64(1), Value::Float64(1.1.into())),
            },
        );

        assert_evaluates_to("4 - 2", Value::Int64(2));
        assert_evaluates_to("4 - 6", Value::Int64(-2));
        assert_evaluates_to("4.5 - 2.2", Value::Float64(2.3.into()));

        assert_evaluates_to("4 * 2", Value::Int64(8));
        assert_evaluates_to("0.5 * 2.2", Value::Float64(1.1.into()));

        assert_evaluates_to("4 / 2", Value::Int64(2));
        assert_evaluates_to("4 / 3", Value::Int64(1));
        assert_evaluates_to("4.0 / 2.0", Value::Float64(2.0.into()));
        assert_evaluates_to("5.1 / 2.5", Value::Float64(2.04.into()));

        assert_evaluates_to("5 % 2", Value::Int64(1));
        assert_evaluates_to("5.5 % 2.5", Value::Float64(0.5.into()));
    }

    #[test]
    fn exec_comparison() {
        assert_evaluates_to("1 = 1", Value::Bool(true));
        assert_evaluates_to("1 = 2", Value::Bool(false));
        assert_evaluates_to("1 != 2", Value::Bool(true));
        assert_evaluates_to("1.1 = 1.1", Value::Bool(true));
        assert_evaluates_to("1.2 = 1.22", Value::Bool(false));
        assert_evaluates_to("1.2 != 1.22", Value::Bool(true));

        assert_evaluates_to("1 < 2", Value::Bool(true));
        assert_evaluates_to("1 < 1", Value::Bool(false));
        assert_evaluates_to("1 <= 2", Value::Bool(true));
        assert_evaluates_to("1 <= 1", Value::Bool(true));
        assert_evaluates_to("3 > 2", Value::Bool(true));
        assert_evaluates_to("3 > 3", Value::Bool(false));
        assert_evaluates_to("3 >= 2", Value::Bool(true));
        assert_evaluates_to("3 >= 3", Value::Bool(true));
    }

    #[test]
    fn exec_pattern_match() {
        assert_evaluates_to(
            "'my name is yoshikage kira' LIKE 'kira'",
            Value::Bool(true),
        );
        assert_evaluates_to(
            "'my name is yoshikage kira' LIKE 'KIRA'",
            Value::Bool(false),
        );
        assert_evaluates_to(
            "'my name is yoshikage kira' LIKE 'kira yoshikage'",
            Value::Bool(false),
        );

        assert_evaluates_to(
            "'my name is Yoshikage Kira' ILIKE 'kira'",
            Value::Bool(true),
        );
        assert_evaluates_to(
            "'my name is Yoshikage Kira' ILIKE 'KIRA'",
            Value::Bool(true),
        );
        assert_evaluates_to(
            "'my name is Yoshikage Kira' ILIKE 'KIRAA'",
            Value::Bool(false),
        );
    }

    #[test]
    fn exec_unary() {
        assert_evaluates_to("+1", Value::Int64(1));
        assert_evaluates_to("+ -1", Value::Int64(-1));
        assert_evaluates_to("-1", Value::Int64(-1));
        assert_evaluates_to("- -1", Value::Int64(1));
        assert_evaluates_to("not true", Value::Bool(false));
        assert_evaluates_to("not false", Value::Bool(true));

        assert_evaluates_to("true is true", Value::Bool(true));
        assert_evaluates_to("false is false", Value::Bool(true));
        assert_evaluates_to("false is true", Value::Bool(false));
        assert_evaluates_to("true is false", Value::Bool(false));
        assert_fails_with(
            "1 is true",
            ValueUnaryOpError {
                operator: UnOp::IsTrue,
                value: Value::Int64(1),
            },
        );

        assert_evaluates_to("NULL is NULL", Value::Bool(true));
        assert_evaluates_to("NULL is not NULL", Value::Bool(false));
        assert_evaluates_to("1 is NULL", Value::Bool(false));
        assert_evaluates_to("1 is not NULL", Value::Bool(true));
        assert_evaluates_to("0 is not NULL", Value::Bool(true));
        assert_evaluates_to("'' is not NULL", Value::Bool(true));
    }

    #[test]
    fn exec_wildcard() {
        let outcome = exec_expr_no_context(Expr::Wildcard);
        assert_eq!(outcome, Err(ExprExecError::CannotExecute(Expr::Wildcard)));
    }

    #[test]
    fn exec_column_ref() {
        let primary_key = ColumnOptionDef {
            name: None,
            option: ColumnOption::Unique { is_primary: true },
        };
        let unique = ColumnOptionDef {
            name: None,
            option: ColumnOption::Unique { is_primary: false },
        };
        let columns = vec![
            Column::new("col1".into(), DataType::Int(None), vec![primary_key], false),
            Column::new("col2".into(), DataType::Int(None), vec![unique], false),
            Column::new("col3".into(), DataType::String, vec![], false),
        ];

        let mut table = Table::new("table1".into(), columns);
        table.new_row(vec![
            Value::Int64(4),
            Value::Int64(10),
            Value::String("brr".to_owned()),
        ]);

        // Sanity check: the table holds exactly the row inserted above.
        assert_eq!(
            table.all_data(),
            vec![Row::new(vec![
                Value::Int64(4),
                Value::Int64(10),
                Value::String("brr".to_owned())
            ])]
        );

        let cases = [
            ("col1", Value::Int64(4)),
            ("col3", Value::String("brr".to_owned())),
            ("col1 = 4", Value::Bool(true)),
            ("col1 + 1", Value::Int64(5)),
            ("col1 + col2", Value::Int64(14)),
            ("col1 + col2 = 10 or col1 * col2 = 40", Value::Bool(true)),
        ];
        for (sql, expected) in cases {
            assert_eq!(
                exec_str_with_context(sql, &table, &table.all_data()[0]),
                Ok(expected)
            );
        }
    }
}