quill_sql/expression/binary.rs

1use crate::catalog::Schema;
2use crate::catalog::{Column, DataType};
3use crate::error::QuillSQLError;
4use crate::error::QuillSQLResult;
5use crate::expression::{Expr, ExprTrait};
6use crate::storage::tuple::Tuple;
7use crate::utils::scalar::ScalarValue;
8use std::cmp::Ordering;
9
10fn numeric_binary_op<F>(left: ScalarValue, right: ScalarValue, op: F) -> QuillSQLResult<ScalarValue>
11where
12    F: Fn(ScalarValue, ScalarValue) -> QuillSQLResult<ScalarValue>,
13{
14    let coercion_type =
15        DataType::comparison_numeric_coercion(&left.data_type(), &right.data_type())?;
16    let l_cast = left.cast_to(&coercion_type)?;
17    let r_cast = right.cast_to(&coercion_type)?;
18    op(l_cast, r_cast)
19}
20
21/// Binary expression
22#[derive(Clone, PartialEq, Eq, Debug)]
23pub struct BinaryExpr {
24    /// Left-hand side of the expression
25    pub left: Box<Expr>,
26    /// The comparison operator
27    pub op: BinaryOp,
28    /// Right-hand side of the expression
29    pub right: Box<Expr>,
30}
31
32impl ExprTrait for BinaryExpr {
33    fn data_type(&self, input_schema: &Schema) -> QuillSQLResult<DataType> {
34        let left_type = self.left.data_type(input_schema)?;
35        let right_type = self.right.data_type(input_schema)?;
36        match self.op {
37            BinaryOp::Gt
38            | BinaryOp::Lt
39            | BinaryOp::GtEq
40            | BinaryOp::LtEq
41            | BinaryOp::Eq
42            | BinaryOp::NotEq
43            | BinaryOp::And
44            | BinaryOp::Or => Ok(DataType::Boolean),
45            BinaryOp::Plus | BinaryOp::Minus | BinaryOp::Multiply | BinaryOp::Divide => {
46                DataType::comparison_numeric_coercion(&left_type, &right_type)
47            }
48        }
49    }
50
51    fn nullable(&self, input_schema: &Schema) -> QuillSQLResult<bool> {
52        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
53    }
54
55    fn evaluate(&self, tuple: &Tuple) -> QuillSQLResult<ScalarValue> {
56        let l = self.left.evaluate(tuple)?;
57        let r = self.right.evaluate(tuple)?;
58        match self.op {
59            BinaryOp::Gt => evaluate_comparison(l, r, &[Ordering::Greater]),
60            BinaryOp::Lt => evaluate_comparison(l, r, &[Ordering::Less]),
61            BinaryOp::GtEq => evaluate_comparison(l, r, &[Ordering::Greater, Ordering::Equal]),
62            BinaryOp::LtEq => evaluate_comparison(l, r, &[Ordering::Less, Ordering::Equal]),
63            BinaryOp::Eq => evaluate_comparison(l, r, &[Ordering::Equal]),
64            BinaryOp::NotEq => evaluate_comparison(l, r, &[Ordering::Greater, Ordering::Less]),
65            BinaryOp::And => {
66                let l_bool = l.as_boolean()?;
67                let r_bool = r.as_boolean()?;
68                Ok(ScalarValue::Boolean(Some(
69                    l_bool.unwrap_or(false) && r_bool.unwrap_or(false),
70                )))
71            }
72            BinaryOp::Or => {
73                let l_bool = l.as_boolean()?;
74                let r_bool = r.as_boolean()?;
75                Ok(ScalarValue::Boolean(Some(
76                    l_bool.unwrap_or(false) || r_bool.unwrap_or(false),
77                )))
78            }
79            BinaryOp::Plus => numeric_binary_op(l, r, |a, b| a.wrapping_add(b)),
80            BinaryOp::Minus => numeric_binary_op(l, r, |a, b| a.wrapping_sub(b)),
81            BinaryOp::Multiply => numeric_binary_op(l, r, |a, b| a.wrapping_mul(b)),
82            BinaryOp::Divide => numeric_binary_op(l, r, |a, b| a.wrapping_div(b)),
83        }
84    }
85
86    fn to_column(&self, input_schema: &Schema) -> QuillSQLResult<Column> {
87        Ok(Column::new(
88            format!("{self}"),
89            self.data_type(input_schema)?,
90            self.nullable(input_schema)?,
91        ))
92    }
93}
94
95fn evaluate_comparison(
96    left: ScalarValue,
97    right: ScalarValue,
98    accepted_orderings: &[Ordering],
99) -> QuillSQLResult<ScalarValue> {
100    let coercion_type =
101        DataType::comparison_numeric_coercion(&left.data_type(), &right.data_type())?;
102    let order = left
103        .cast_to(&coercion_type)?
104        .partial_cmp(&right.cast_to(&coercion_type)?)
105        .ok_or(QuillSQLError::Execution(format!(
106            "Can not compare {:?} and {:?}",
107            left, right
108        )))?;
109    Ok(ScalarValue::Boolean(Some(
110        accepted_orderings.contains(&order),
111    )))
112}
113
/// Operators that can combine two operands in a [`BinaryExpr`].
///
/// Values are produced from `sqlparser`'s `BinaryOperator` via `TryFrom`.
///
/// `PartialOrd` is derived, so variants order by declaration position. Reordering
/// the variants below changes that ordering.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub enum BinaryOp {
    /// Addition (`+`).
    Plus,
    /// Subtraction (`-`).
    Minus,
    /// Multiplication (`*`).
    Multiply,
    /// Division (`/`).
    Divide,
    /// Greater than (`>`).
    Gt,
    /// Less than (`<`).
    Lt,
    /// Greater than or equal (`>=`).
    GtEq,
    /// Less than or equal (`<=`).
    LtEq,
    /// Equality (`=`).
    Eq,
    /// Inequality (`<>` / `!=`).
    NotEq,
    /// Logical conjunction (`AND`).
    And,
    /// Logical disjunction (`OR`).
    Or,
}
129
130impl TryFrom<&sqlparser::ast::BinaryOperator> for BinaryOp {
131    type Error = QuillSQLError;
132
133    fn try_from(value: &sqlparser::ast::BinaryOperator) -> Result<Self, Self::Error> {
134        match value {
135            sqlparser::ast::BinaryOperator::Plus => Ok(BinaryOp::Plus),
136            sqlparser::ast::BinaryOperator::Minus => Ok(BinaryOp::Minus),
137            sqlparser::ast::BinaryOperator::Multiply => Ok(BinaryOp::Multiply),
138            sqlparser::ast::BinaryOperator::Divide => Ok(BinaryOp::Divide),
139            sqlparser::ast::BinaryOperator::Gt => Ok(BinaryOp::Gt),
140            sqlparser::ast::BinaryOperator::Lt => Ok(BinaryOp::Lt),
141            sqlparser::ast::BinaryOperator::GtEq => Ok(BinaryOp::GtEq),
142            sqlparser::ast::BinaryOperator::LtEq => Ok(BinaryOp::LtEq),
143            sqlparser::ast::BinaryOperator::Eq => Ok(BinaryOp::Eq),
144            sqlparser::ast::BinaryOperator::NotEq => Ok(BinaryOp::NotEq),
145            sqlparser::ast::BinaryOperator::And => Ok(BinaryOp::And),
146            sqlparser::ast::BinaryOperator::Or => Ok(BinaryOp::Or),
147            _ => Err(QuillSQLError::NotSupport(format!(
148                "sqlparser binary operator {} not supported",
149                value
150            ))),
151        }
152    }
153}
154
155impl std::fmt::Display for BinaryExpr {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        write!(f, "({} {} {})", self.left, self.op, self.right)
158    }
159}
160
161impl std::fmt::Display for BinaryOp {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        write!(f, "{self:?}")
164    }
165}