use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use itertools::Itertools;
pub use self::scalars::{ArrayData, Scalar, StructData};
mod scalars;
/// An operator that combines a left-hand and a right-hand operand.
///
/// Covers arithmetic (`+`, `-`, `*`, `/`), ordering and equality comparisons, `DISTINCT`, and
/// the membership tests `IN` / `NOT IN`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// Arithmetic addition.
    Plus,
    /// Arithmetic subtraction.
    Minus,
    /// Arithmetic multiplication.
    Multiply,
    /// Arithmetic division.
    Divide,
    /// Strictly-less-than comparison.
    LessThan,
    /// Less-than-or-equal comparison.
    LessThanOrEqual,
    /// Strictly-greater-than comparison.
    GreaterThan,
    /// Greater-than-or-equal comparison.
    GreaterThanOrEqual,
    /// Equality comparison.
    Equal,
    /// Inequality comparison.
    NotEqual,
    /// `DISTINCT` comparison of the two operands.
    Distinct,
    /// Membership test (`IN`).
    In,
    /// Negated membership test (`NOT IN`).
    NotIn,
}

impl BinaryOperator {
    /// Returns the operator to use once the two operands have traded places, if there is one.
    ///
    /// Ordering comparisons flip direction (`a > b` becomes `b < a`), while `Equal`, `NotEqual`,
    /// `Plus` and `Multiply` map to themselves. Every other operator yields `None`.
    pub(crate) fn commute(&self) -> Option<BinaryOperator> {
        let swapped = match self {
            Self::GreaterThan => Self::LessThan,
            Self::GreaterThanOrEqual => Self::LessThanOrEqual,
            Self::LessThan => Self::GreaterThan,
            Self::LessThanOrEqual => Self::GreaterThanOrEqual,
            Self::Equal | Self::NotEqual | Self::Plus | Self::Multiply => self.clone(),
            // No commuted form is provided for the remaining operators.
            Self::Minus | Self::Divide | Self::Distinct | Self::In | Self::NotIn => return None,
        };
        Some(swapped)
    }

    /// Returns the operator with the opposite outcome, if there is one.
    ///
    /// Comparisons map to their complement (`<` to `>=`, `=` to `!=`, `IN` to `NOT IN`, and so
    /// on). The arithmetic operators and `Distinct` yield `None`.
    pub(crate) fn invert(&self) -> Option<BinaryOperator> {
        let negated = match self {
            Self::LessThan => Self::GreaterThanOrEqual,
            Self::LessThanOrEqual => Self::GreaterThan,
            Self::GreaterThan => Self::LessThanOrEqual,
            Self::GreaterThanOrEqual => Self::LessThan,
            Self::Equal => Self::NotEqual,
            Self::NotEqual => Self::Equal,
            Self::In => Self::NotIn,
            Self::NotIn => Self::In,
            // No inverse is provided for the remaining operators.
            Self::Plus | Self::Minus | Self::Multiply | Self::Divide | Self::Distinct => {
                return None
            }
        };
        Some(negated)
    }
}
/// An operator applied to an arbitrary number of operand expressions.
///
/// Derives `Eq` and `Hash` (like `BinaryOperator`) so values can be used as keys in hashed
/// collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VariadicOperator {
    /// Logical conjunction of all operands.
    And,
    /// Logical disjunction of all operands.
    Or,
}

impl VariadicOperator {
    /// Returns the other operator: `And` becomes `Or` and `Or` becomes `And`.
    pub(crate) fn invert(&self) -> VariadicOperator {
        match self {
            Self::And => Self::Or,
            Self::Or => Self::And,
        }
    }
}
/// Renders a [`BinaryOperator`] as its symbol (`+`, `<=`, ...) or keyword (`IN`, `DISTINCT`, ...).
impl Display for BinaryOperator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Pick the token first, then emit it with a single write.
        let token = match self {
            Self::Plus => "+",
            Self::Minus => "-",
            Self::Multiply => "*",
            Self::Divide => "/",
            Self::LessThan => "<",
            Self::LessThanOrEqual => "<=",
            Self::GreaterThan => ">",
            Self::GreaterThanOrEqual => ">=",
            Self::Equal => "=",
            Self::NotEqual => "!=",
            Self::Distinct => "DISTINCT",
            Self::In => "IN",
            Self::NotIn => "NOT IN",
        };
        f.write_str(token)
    }
}
/// An operator applied to a single operand expression.
///
/// Derives `Eq` and `Hash` (like `BinaryOperator`) so values can be used as keys in hashed
/// collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// Logical negation of the operand.
    Not,
    /// Test of whether the operand is null.
    IsNull,
}
/// A tree-structured expression.
///
/// Leaves are literals and column references; inner nodes are struct constructors and
/// unary, binary or variadic operations over nested expressions.
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
/// A literal (constant) scalar value.
Literal(Scalar),
/// A reference to a column, identified by name.
Column(String),
/// A struct assembled from the given sub-expressions.
Struct(Vec<Expression>),
/// A binary operation over two operand expressions.
BinaryOperation {
/// The operator to apply.
op: BinaryOperator,
/// The left-hand operand.
left: Box<Expression>,
/// The right-hand operand.
right: Box<Expression>,
},
/// A unary operation over a single operand expression.
UnaryOperation {
/// The operator to apply.
op: UnaryOperator,
/// The operand.
expr: Box<Expression>,
},
/// A variadic operation over any number of operand expressions.
VariadicOperation {
/// The operator to apply.
op: VariadicOperator,
/// The operands.
exprs: Vec<Expression>,
},
}
impl<T: Into<Scalar>> From<T> for Expression {
fn from(value: T) -> Self {
Self::literal(value)
}
}
/// Renders an expression tree as text.
///
/// Binary operations are written infix (`left op right`) without parentheses, except
/// `Distinct`, which is written as `DISTINCT(left, right)`. Structs and variadic operations
/// are written as `NAME(a, b, ...)`.
impl Display for Expression {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Literal(value) => write!(f, "{value}"),
            Self::Column(name) => write!(f, "Column({name})"),
            Self::Struct(fields) => write!(f, "Struct({})", fields.iter().join(", ")),
            // `Distinct` must be matched before the generic binary arm below.
            Self::BinaryOperation {
                op: BinaryOperator::Distinct,
                left,
                right,
            } => write!(f, "DISTINCT({left}, {right})"),
            Self::BinaryOperation { op, left, right } => write!(f, "{left} {op} {right}"),
            Self::UnaryOperation {
                op: UnaryOperator::Not,
                expr,
            } => write!(f, "NOT {expr}"),
            Self::UnaryOperation {
                op: UnaryOperator::IsNull,
                expr,
            } => write!(f, "{expr} IS NULL"),
            Self::VariadicOperation { op, exprs } => {
                // Both operators share the same `NAME(operand, ...)` layout.
                let keyword = match op {
                    VariadicOperator::And => "AND",
                    VariadicOperator::Or => "OR",
                };
                write!(f, "{keyword}({})", exprs.iter().join(", "))
            }
        }
    }
}
impl Expression {
    /// Returns the names of all columns referenced anywhere in this expression tree.
    pub fn references(&self) -> HashSet<&str> {
        self.walk()
            .filter_map(|node| match node {
                Self::Column(name) => Some(name.as_str()),
                _ => None,
            })
            .collect()
    }

    /// Creates a column reference expression from `name`.
    pub fn column(name: impl ToString) -> Self {
        let name = name.to_string();
        Self::Column(name)
    }

    /// Creates a literal expression from any value convertible into a [`Scalar`].
    pub fn literal(value: impl Into<Scalar>) -> Self {
        let scalar: Scalar = value.into();
        Self::Literal(scalar)
    }

    /// Creates a struct expression from the given sub-expressions.
    pub fn struct_expr(exprs: impl IntoIterator<Item = Self>) -> Self {
        let fields: Vec<Self> = exprs.into_iter().collect();
        Self::Struct(fields)
    }

    /// Creates a unary operation applying `op` to `expr`.
    pub fn unary(op: UnaryOperator, expr: impl Into<Expression>) -> Self {
        let expr: Box<Expression> = Box::new(expr.into());
        Self::UnaryOperation { op, expr }
    }

    /// Creates a binary operation applying `op` to `lhs` and `rhs`.
    pub fn binary(
        op: BinaryOperator,
        lhs: impl Into<Expression>,
        rhs: impl Into<Expression>,
    ) -> Self {
        let left: Box<Expression> = Box::new(lhs.into());
        let right: Box<Expression> = Box::new(rhs.into());
        Self::BinaryOperation { op, left, right }
    }

    /// Creates a variadic operation applying `op` to all of `exprs`.
    pub fn variadic(op: VariadicOperator, exprs: impl IntoIterator<Item = Self>) -> Self {
        Self::VariadicOperation {
            op,
            exprs: exprs.into_iter().collect(),
        }
    }

    /// Creates an `AND` over all of `exprs`.
    pub fn and_from(exprs: impl IntoIterator<Item = Self>) -> Self {
        Self::VariadicOperation {
            op: VariadicOperator::And,
            exprs: exprs.into_iter().collect(),
        }
    }

    /// Creates an `OR` over all of `exprs`.
    pub fn or_from(exprs: impl IntoIterator<Item = Self>) -> Self {
        Self::VariadicOperation {
            op: VariadicOperator::Or,
            exprs: exprs.into_iter().collect(),
        }
    }

    /// Creates the expression `self IS NULL`.
    pub fn is_null(self) -> Self {
        Self::UnaryOperation {
            op: UnaryOperator::IsNull,
            expr: Box::new(self),
        }
    }

    /// Creates the expression `self = other`.
    pub fn eq(self, other: Self) -> Self {
        Self::binary(BinaryOperator::Equal, self, other)
    }

    /// Creates the expression `self != other`.
    pub fn ne(self, other: Self) -> Self {
        Self::binary(BinaryOperator::NotEqual, self, other)
    }

    /// Creates the expression `self <= other`.
    pub fn le(self, other: Self) -> Self {
        Self::binary(BinaryOperator::LessThanOrEqual, self, other)
    }

    /// Creates the expression `self < other`.
    pub fn lt(self, other: Self) -> Self {
        Self::binary(BinaryOperator::LessThan, self, other)
    }

    /// Creates the expression `self >= other`.
    pub fn ge(self, other: Self) -> Self {
        Self::binary(BinaryOperator::GreaterThanOrEqual, self, other)
    }

    /// Creates the expression `self > other`.
    pub fn gt(self, other: Self) -> Self {
        Self::binary(BinaryOperator::GreaterThan, self, other)
    }

    /// Creates the expression `self >= other`; same result as [`Expression::ge`].
    pub fn gt_eq(self, other: Self) -> Self {
        Self::ge(self, other)
    }

    /// Creates the expression `self <= other`; same result as [`Expression::le`].
    pub fn lt_eq(self, other: Self) -> Self {
        Self::le(self, other)
    }

    /// Creates the expression `AND(self, other)`.
    pub fn and(self, other: Self) -> Self {
        Self::and_from([self, other])
    }

    /// Creates the expression `OR(self, other)`.
    pub fn or(self, other: Self) -> Self {
        Self::or_from([self, other])
    }

    /// Creates the expression `DISTINCT(self, other)`.
    pub fn distinct(self, other: Self) -> Self {
        Self::binary(BinaryOperator::Distinct, self, other)
    }

    /// Iterates over this expression and every nested sub-expression.
    ///
    /// Uses an explicit stack instead of recursion: each step pops one node, queues its
    /// children, and yields the node. Every node is yielded before any of its descendants.
    fn walk(&self) -> impl Iterator<Item = &Self> + '_ {
        let mut pending = vec![self];
        std::iter::from_fn(move || {
            let current = pending.pop()?;
            match current {
                // Leaves have no children to queue.
                Self::Literal(_) | Self::Column(_) => {}
                Self::Struct(children) | Self::VariadicOperation { exprs: children, .. } => {
                    pending.extend(children.iter());
                }
                Self::BinaryOperation { left, right, .. } => {
                    pending.extend([&**left, &**right]);
                }
                Self::UnaryOperation { expr, .. } => pending.push(&**expr),
            }
            Some(current)
        })
    }
}
impl std::ops::Not for Expression {
type Output = Self;
fn not(self) -> Self {
Self::unary(UnaryOperator::Not, self)
}
}
impl std::ops::Add<Expression> for Expression {
type Output = Self;
fn add(self, rhs: Expression) -> Self::Output {
Self::binary(BinaryOperator::Plus, self, rhs)
}
}
impl std::ops::Sub<Expression> for Expression {
type Output = Self;
fn sub(self, rhs: Expression) -> Self {
Self::binary(BinaryOperator::Minus, self, rhs)
}
}
impl std::ops::Mul<Expression> for Expression {
type Output = Self;
fn mul(self, rhs: Expression) -> Self {
Self::binary(BinaryOperator::Multiply, self, rhs)
}
}
impl std::ops::Div<Expression> for Expression {
type Output = Self;
fn div(self, rhs: Expression) -> Self {
Self::binary(BinaryOperator::Divide, self, rhs)
}
}
#[cfg(test)]
mod tests {
    use super::Expression as Expr;

    /// Checks the `Display` output for a representative set of expression trees.
    #[test]
    fn test_expression_format() {
        // Fresh reference to column `x` for every use.
        let x = || Expr::column("x");

        let cases = [
            (x(), "Column(x)"),
            (x().eq(Expr::literal(2)), "Column(x) = 2"),
            (
                (x() - Expr::literal(4)).lt(Expr::literal(10)),
                "Column(x) - 4 < 10",
            ),
            // Nested binary operations are rendered without parentheses.
            (
                (x() + Expr::literal(4)) / Expr::literal(10) * Expr::literal(42),
                "Column(x) + 4 / 10 * 42",
            ),
            (
                x().gt_eq(Expr::literal(2))
                    .and(x().lt_eq(Expr::literal(10))),
                "AND(Column(x) >= 2, Column(x) <= 10)",
            ),
            (
                Expr::and_from([
                    x().gt_eq(Expr::literal(2)),
                    x().lt_eq(Expr::literal(10)),
                    x().lt_eq(Expr::literal(100)),
                ]),
                "AND(Column(x) >= 2, Column(x) <= 10, Column(x) <= 100)",
            ),
            (
                x().gt(Expr::literal(2)).or(x().lt(Expr::literal(10))),
                "OR(Column(x) > 2, Column(x) < 10)",
            ),
            (x().eq(Expr::literal("foo")), "Column(x) = 'foo'"),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}