// proof_of_sql/base/database/expression_evaluation.rs

1use super::{ExpressionEvaluationError, ExpressionEvaluationResult};
2use crate::base::{
3    database::{OwnedColumn, OwnedTable},
4    math::{
5        decimal::{try_convert_intermediate_decimal_to_scalar, DecimalError, Precision},
6        BigDecimalExt,
7    },
8    scalar::Scalar,
9};
10use alloc::{format, string::ToString, vec};
11use proof_of_sql_parser::intermediate_ast::{Expression, Literal};
12use sqlparser::ast::{BinaryOperator, Ident, UnaryOperator};
13
14impl<S: Scalar> OwnedTable<S> {
15    /// Evaluate an expression on the table.
16    pub fn evaluate(&self, expr: &Expression) -> ExpressionEvaluationResult<OwnedColumn<S>> {
17        match expr {
18            Expression::Column(identifier) => self.evaluate_column(&Ident::from(*identifier)),
19            Expression::Literal(lit) => self.evaluate_literal(lit),
20            Expression::Binary { op, left, right } => {
21                self.evaluate_binary_expr(&(*op).into(), left, right)
22            }
23            Expression::Unary { op, expr } => self.evaluate_unary_expr((*op).into(), expr),
24            _ => Err(ExpressionEvaluationError::Unsupported {
25                expression: format!("Expression {expr:?} is not supported yet"),
26            }),
27        }
28    }
29
30    fn evaluate_column(&self, identifier: &Ident) -> ExpressionEvaluationResult<OwnedColumn<S>> {
31        Ok(self
32            .inner_table()
33            .get(identifier)
34            .ok_or(ExpressionEvaluationError::ColumnNotFound {
35                error: identifier.to_string(),
36            })?
37            .clone())
38    }
39
40    fn evaluate_literal(&self, lit: &Literal) -> ExpressionEvaluationResult<OwnedColumn<S>> {
41        let len = self.num_rows();
42        match lit {
43            Literal::Boolean(b) => Ok(OwnedColumn::Boolean(vec![*b; len])),
44            Literal::BigInt(i) => Ok(OwnedColumn::BigInt(vec![*i; len])),
45            Literal::Int128(i) => Ok(OwnedColumn::Int128(vec![*i; len])),
46            Literal::Decimal(d) => {
47                let raw_scale = d.scale();
48                let scale = raw_scale
49                    .try_into()
50                    .map_err(|_| DecimalError::InvalidScale {
51                        scale: raw_scale.to_string(),
52                    })?;
53                let precision = Precision::try_from(d.precision())?;
54                let scalar = try_convert_intermediate_decimal_to_scalar(d, precision, scale)?;
55                Ok(OwnedColumn::Decimal75(precision, scale, vec![scalar; len]))
56            }
57            Literal::VarChar(s) => Ok(OwnedColumn::VarChar(vec![s.clone(); len])),
58            Literal::Timestamp(its) => Ok(OwnedColumn::TimestampTZ(
59                its.timeunit(),
60                its.timezone(),
61                vec![its.timestamp().timestamp(); len],
62            )),
63        }
64    }
65
66    fn evaluate_unary_expr(
67        &self,
68        op: UnaryOperator,
69        expr: &Expression,
70    ) -> ExpressionEvaluationResult<OwnedColumn<S>> {
71        let column = self.evaluate(expr)?;
72        match op {
73            UnaryOperator::Not => Ok(column.element_wise_not()?),
74            // Handle unsupported unary operators
75            _ => Err(ExpressionEvaluationError::Unsupported {
76                expression: format!("Unary operator '{op}' is not supported."),
77            }),
78        }
79    }
80
81    fn evaluate_binary_expr(
82        &self,
83        op: &BinaryOperator,
84        left: &Expression,
85        right: &Expression,
86    ) -> ExpressionEvaluationResult<OwnedColumn<S>> {
87        let left = self.evaluate(left)?;
88        let right = self.evaluate(right)?;
89        match op {
90            BinaryOperator::And => Ok(left.element_wise_and(&right)?),
91            BinaryOperator::Or => Ok(left.element_wise_or(&right)?),
92            BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?),
93            BinaryOperator::Gt => Ok(left.element_wise_gt(&right)?),
94            BinaryOperator::Lt => Ok(left.element_wise_lt(&right)?),
95            BinaryOperator::Plus => Ok(left.element_wise_add(&right)?),
96            BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?),
97            BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?),
98            BinaryOperator::Divide => Ok(left.element_wise_div(&right)?),
99            _ => Err(ExpressionEvaluationError::Unsupported {
100                expression: format!("Binary operator '{op}' is not supported."),
101            }),
102        }
103    }
104}