use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use itertools::Itertools;
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
pub use self::column_names::{
column_expr, column_expr_ref, column_name, column_pred, joined_column_expr, joined_column_name,
ColumnName,
};
pub use self::scalars::{ArrayData, DecimalData, MapData, Scalar, StructData};
use crate::kernel_predicates::{
DirectDataSkippingPredicateEvaluator, DirectPredicateEvaluator,
IndirectDataSkippingPredicateEvaluator,
};
use crate::schema::SchemaRef;
use crate::transforms::ExpressionTransform;
use crate::{DataType, DeltaResult, DynPartialEq};
mod column_names;
pub(crate) mod literal_expression_transform;
pub(crate) use literal_expression_transform::literal_expression_transform;
mod scalars;
/// Shared, reference-counted handle to an [`Expression`].
pub type ExpressionRef = std::sync::Arc<Expression>;
/// Shared, reference-counted handle to a [`Predicate`].
pub type PredicateRef = std::sync::Arc<Predicate>;
/// Operators for predicates that take a single expression operand.
///
/// Derives `Eq` and `Hash` like the other operator enums in this module, so it can be used
/// as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryPredicateOp {
    /// Tests whether the operand is null (rendered as `<expr> IS NULL`).
    IsNull,
}
/// Operators for predicates that compare two expression operands.
///
/// There are no `<=`, `>=` or `!=` variants: [`Predicate::le`], [`Predicate::ge`] and
/// [`Predicate::ne`] are built by negating `GreaterThan`, `LessThan` and `Equal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinaryPredicateOp {
    /// `left < right`.
    LessThan,
    /// `left > right`.
    GreaterThan,
    /// `left = right`.
    Equal,
    /// Rendered as `DISTINCT(left, right)`. Presumably SQL `IS DISTINCT FROM` (null-safe
    /// inequality), which fits it not being null-intolerant; confirm against the evaluator.
    Distinct,
    /// Membership test `left IN right`, where `right` is e.g. an array literal.
    In,
}
/// Operators for expressions that take a single expression operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryExpressionOp {
    /// Converts the operand to JSON (rendered as `TO_JSON(expr)`). Presumably it yields a JSON
    /// string; confirm against the evaluator.
    ToJson,
}
/// Arithmetic operators for expressions with two operands (see also the `std::ops` impls on
/// [`Expression`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinaryExpressionOp {
    /// `left + right`.
    Plus,
    /// `left - right`.
    Minus,
    /// `left * right`.
    Multiply,
    /// `left / right`.
    Divide,
}
/// Operators for expressions that take any number of operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VariadicExpressionOp {
    /// Rendered as `COALESCE(...)`. Presumably returns the first non-null operand, as in SQL;
    /// confirm against the evaluator.
    Coalesce,
}
/// Logical connectives that combine any number of child predicates.
///
/// Derives `Eq` and `Hash` like the other operator enums in this module, so it can be used
/// as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JunctionPredicateOp {
    /// Conjunction of the children. An empty `AND` is built as `true` by
    /// [`Predicate::junction`].
    And,
    /// Disjunction of the children. An empty `OR` is built as `false` by
    /// [`Predicate::junction`].
    Or,
}
/// Callback handed to opaque operations so they can evaluate a child [`Expression`] to a
/// [`Scalar`]. Presumably `None` means the child could not be reduced to a scalar; confirm
/// against the callers that construct it.
pub type ScalarExpressionEvaluator<'a> = dyn Fn(&Expression) -> Option<Scalar> + 'a;
/// An expression operation implemented outside the built-in operator enums. It is carried by
/// [`Expression::Opaque`] as a trait object.
///
/// The `DynPartialEq` supertrait lets [`OpaqueExpression`] compare two ops for equality.
/// `Debug` is what [`Expression`]'s `Display` impl prints for the op.
pub trait OpaqueExpressionOp: DynPartialEq + std::fmt::Debug {
    /// A name identifying this operation.
    fn name(&self) -> &str;
    /// Evaluates this operation over the child `exprs`, using `eval_expr` to evaluate children
    /// to scalars as needed.
    ///
    /// Presumably returns an error when the children cannot be evaluated or have unsupported
    /// types; confirm against implementations.
    fn eval_expr_scalar(
        &self,
        eval_expr: &ScalarExpressionEvaluator<'_>,
        exprs: &[Expression],
    ) -> DeltaResult<Scalar>;
}
/// A predicate operation implemented outside the built-in operator enums. It is carried by
/// [`Predicate::Opaque`] as a trait object.
///
/// The `DynPartialEq` supertrait lets [`OpaquePredicate`] compare two ops for equality.
/// `Debug` is what [`Predicate`]'s `Display` impl prints for the op.
pub trait OpaquePredicateOp: DynPartialEq + std::fmt::Debug {
    /// A name identifying this operation.
    fn name(&self) -> &str;
    /// Evaluates the predicate over the child `exprs`, using `eval_expr` for child
    /// expressions and `eval_pred` for child predicates.
    ///
    /// NOTE(review): `inverted` presumably asks for the result of `NOT(self)`, and `Ok(None)`
    /// presumably means the result is unknown/null. Confirm against the kernel_predicates
    /// evaluators.
    fn eval_pred_scalar(
        &self,
        eval_expr: &ScalarExpressionEvaluator<'_>,
        eval_pred: &DirectPredicateEvaluator<'_>,
        exprs: &[Expression],
        inverted: bool,
    ) -> DeltaResult<Option<bool>>;
    /// Evaluates the predicate directly through a data-skipping `evaluator`. Presumably `None`
    /// means no conclusion could be drawn (so the data cannot be skipped); confirm.
    fn eval_as_data_skipping_predicate(
        &self,
        evaluator: &DirectDataSkippingPredicateEvaluator<'_>,
        exprs: &[Expression],
        inverted: bool,
    ) -> Option<bool>;
    /// Rewrites the predicate into a data-skipping [`Predicate`] using `evaluator`.
    /// Presumably `None` means no rewrite is possible; confirm.
    fn as_data_skipping_predicate(
        &self,
        evaluator: &IndirectDataSkippingPredicateEvaluator<'_>,
        exprs: &[Expression],
        inverted: bool,
    ) -> Option<Predicate>;
}
/// Shared trait-object handle to an opaque expression operation.
pub type OpaqueExpressionOpRef = Arc<dyn OpaqueExpressionOp>;
/// Shared trait-object handle to an opaque predicate operation.
pub type OpaquePredicateOpRef = Arc<dyn OpaquePredicateOp>;
/// A predicate applying a [`UnaryPredicateOp`] to one expression.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UnaryPredicate {
    /// The operator to apply.
    pub op: UnaryPredicateOp,
    /// The operand.
    pub expr: Box<Expression>,
}
/// A predicate comparing two expressions with a [`BinaryPredicateOp`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BinaryPredicate {
    /// The comparison operator.
    pub op: BinaryPredicateOp,
    /// The left-hand operand.
    pub left: Box<Expression>,
    /// The right-hand operand.
    pub right: Box<Expression>,
}
/// An expression applying a [`UnaryExpressionOp`] to one operand.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UnaryExpression {
    /// The operator to apply.
    pub op: UnaryExpressionOp,
    /// The operand.
    pub expr: Box<Expression>,
}
/// An arithmetic expression combining two operands with a [`BinaryExpressionOp`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BinaryExpression {
    /// The arithmetic operator.
    pub op: BinaryExpressionOp,
    /// The left-hand operand.
    pub left: Box<Expression>,
    /// The right-hand operand.
    pub right: Box<Expression>,
}
/// An expression applying a [`VariadicExpressionOp`] to any number of operands.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct VariadicExpression {
    /// The operator to apply.
    pub op: VariadicExpressionOp,
    /// The operands, in order.
    pub exprs: Vec<Expression>,
}
/// An expression that parses JSON produced by `json_expr` according to `output_schema`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ParseJsonExpression {
    /// Expression yielding the JSON input. It is presumably string-valued; confirm against the
    /// evaluator.
    pub json_expr: Box<Expression>,
    /// Schema the parsed JSON is read into. `Display` reports only its field count.
    pub output_schema: SchemaRef,
}
/// An `AND`/`OR` over any number of child predicates.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct JunctionPredicate {
    /// Which connective combines the children.
    pub op: JunctionPredicateOp,
    /// The child predicates.
    pub preds: Vec<Predicate>,
}
/// A predicate whose semantics are supplied by an [`OpaquePredicateOp`] trait object.
///
/// `PartialEq` is implemented manually via `DynPartialEq`. Serde is not supported: the
/// `Predicate::Opaque` variant routes through hooks that always fail.
#[derive(Clone, Debug)]
pub struct OpaquePredicate {
    /// The operation to apply.
    pub op: OpaquePredicateOpRef,
    /// The child expressions the operation is applied to.
    pub exprs: Vec<Expression>,
}
fn fail_serialize_opaque_predicate<S>(
_value: &OpaquePredicate,
_serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
Err(ser::Error::custom("Cannot serialize an Opaque Predicate"))
}
fn fail_deserialize_opaque_predicate<'de, D>(_deserializer: D) -> Result<OpaquePredicate, D::Error>
where
D: Deserializer<'de>,
{
Err(de::Error::custom("Cannot deserialize an Opaque Predicate"))
}
impl OpaquePredicate {
    /// Bundles `op` with its child expressions, collecting them into a `Vec`.
    pub(crate) fn new(
        op: OpaquePredicateOpRef,
        exprs: impl IntoIterator<Item = Expression>,
    ) -> Self {
        Self {
            exprs: exprs.into_iter().collect(),
            op,
        }
    }
}
/// An expression whose semantics are supplied by an [`OpaqueExpressionOp`] trait object.
///
/// `PartialEq` is implemented manually via `DynPartialEq`. Serde is not supported: the
/// `Expression::Opaque` variant routes through hooks that always fail.
#[derive(Clone, Debug)]
pub struct OpaqueExpression {
    /// The operation to apply.
    pub op: OpaqueExpressionOpRef,
    /// The child expressions the operation is applied to.
    pub exprs: Vec<Expression>,
}
impl OpaqueExpression {
    /// Bundles `op` with its child expressions, collecting them into a `Vec`.
    pub(crate) fn new(
        op: OpaqueExpressionOpRef,
        exprs: impl IntoIterator<Item = Expression>,
    ) -> Self {
        Self {
            exprs: exprs.into_iter().collect(),
            op,
        }
    }
}
fn fail_serialize_opaque_expression<S>(
_value: &OpaqueExpression,
_serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
Err(ser::Error::custom("Cannot serialize an Opaque Expression"))
}
fn fail_deserialize_opaque_expression<'de, D>(
_deserializer: D,
) -> Result<OpaqueExpression, D::Error>
where
D: Deserializer<'de>,
{
Err(de::Error::custom("Cannot deserialize an Opaque Expression"))
}
/// The edits a [`Transform`] applies to one named input field.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct FieldTransform {
    /// Expressions emitted at this field's position: inserted after the field, or in place of
    /// it when `is_replace` is set.
    pub exprs: Vec<ExpressionRef>,
    /// When true the original field is not kept. It is dropped if `exprs` is empty, otherwise
    /// replaced by `exprs`.
    pub is_replace: bool,
    /// Set by [`Transform::with_dropped_field_if_exists`]. Presumably it makes a missing input
    /// field acceptable; confirm against the transform evaluator.
    pub optional: bool,
}
/// Describes how to derive a new struct from an input struct.
///
/// It records which input fields to drop or replace, what to insert after particular fields,
/// and which expressions to prepend.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct Transform {
    /// Path to the nested struct being transformed. `None` means the top level.
    pub input_path: Option<ColumnName>,
    /// Per-field edits, keyed by input field name.
    pub field_transforms: HashMap<String, FieldTransform>,
    /// Expressions added via `with_inserted_field(None, ..)`, presumably emitted ahead of all
    /// input fields.
    pub prepended_fields: Vec<ExpressionRef>,
}
impl Transform {
pub fn new_top_level() -> Self {
Self::default()
}
pub fn new_nested<A>(path: impl IntoIterator<Item = A>) -> Self
where
ColumnName: FromIterator<A>,
{
Self {
input_path: Some(ColumnName::new(path)),
..Default::default()
}
}
pub fn with_dropped_field(mut self, name: impl Into<String>) -> Self {
let field_transform = self.field_transform(name);
field_transform.is_replace = true;
self
}
pub fn with_dropped_field_if_exists(mut self, name: impl Into<String>) -> Self {
let field_transform = self.field_transform(name);
field_transform.is_replace = true;
field_transform.optional = true;
self
}
pub fn with_replaced_field(mut self, name: impl Into<String>, expr: ExpressionRef) -> Self {
let field_transform = self.field_transform(name);
field_transform.exprs.push(expr);
field_transform.is_replace = true;
self
}
pub fn with_inserted_field(
mut self,
after: Option<impl Into<String>>,
expr: ExpressionRef,
) -> Self {
match after {
Some(field_name) => self.field_transform(field_name).exprs.push(expr),
None => self.prepended_fields.push(expr),
}
self
}
pub fn is_identity(&self) -> bool {
self.prepended_fields.is_empty() && self.field_transforms.is_empty()
}
pub fn input_path(&self) -> Option<&ColumnName> {
self.input_path.as_ref()
}
fn field_transform(&mut self, field_name: impl Into<String>) -> &mut FieldTransform {
self.field_transforms.entry(field_name.into()).or_default()
}
}
/// An expression tree evaluated against data.
///
/// Serde support is derived. The `Opaque` variant is wired to hooks that always fail, because
/// its op is a trait object with no serialized form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expression {
    /// A constant value.
    Literal(Scalar),
    /// A reference to a (possibly nested) column.
    Column(ColumnName),
    /// A predicate used in expression position (see [`Expression::from_pred`]).
    Predicate(Box<Predicate>),
    /// A struct built from one expression per field.
    ///
    /// The optional second element is the nullability predicate passed to
    /// [`Expression::struct_with_nullability_from`]. Presumably it decides when the whole
    /// struct is null; confirm against the evaluator.
    Struct(Vec<ExpressionRef>, Option<ExpressionRef>),
    /// Rewrites the fields of a struct as described by a [`Transform`].
    Transform(Transform),
    /// A [`UnaryExpressionOp`] applied to one operand.
    Unary(UnaryExpression),
    /// A [`BinaryExpressionOp`] applied to two operands.
    Binary(BinaryExpression),
    /// A [`VariadicExpressionOp`] applied to any number of operands.
    Variadic(VariadicExpression),
    /// An operation implemented by an [`OpaqueExpressionOp`] trait object. It cannot be
    /// serialized or deserialized.
    #[serde(serialize_with = "fail_serialize_opaque_expression")]
    #[serde(deserialize_with = "fail_deserialize_opaque_expression")]
    Opaque(OpaqueExpression),
    /// An expression known only by name (rendered as `<unknown: name>`).
    Unknown(String),
    /// Parses a JSON-valued expression into a struct with a given schema.
    ParseJson(ParseJsonExpression),
    /// Converts a map-valued expression into a struct. Presumably map keys become field names;
    /// confirm against the evaluator.
    MapToStruct(MapToStructExpression),
}
/// A boolean-valued expression tree.
///
/// Serde support is derived. The `Opaque` variant is wired to hooks that always fail, because
/// its op is a trait object with no serialized form.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Predicate {
    /// An arbitrary [`Expression`] used in predicate position, e.g. a boolean column or
    /// literal.
    BooleanExpression(Expression),
    /// Logical negation of the inner predicate.
    Not(Box<Predicate>),
    /// A [`UnaryPredicateOp`] such as `IS NULL` applied to one expression.
    Unary(UnaryPredicate),
    /// A [`BinaryPredicateOp`] comparison between two expressions.
    Binary(BinaryPredicate),
    /// An `AND`/`OR` over any number of child predicates.
    Junction(JunctionPredicate),
    /// A predicate implemented by an [`OpaquePredicateOp`] trait object. It cannot be
    /// serialized or deserialized.
    #[serde(serialize_with = "fail_serialize_opaque_predicate")]
    #[serde(deserialize_with = "fail_deserialize_opaque_predicate")]
    Opaque(OpaquePredicate),
    /// A predicate known only by name (rendered as `<unknown: name>`).
    Unknown(String),
}
impl BinaryPredicateOp {
pub(crate) fn is_null_intolerant(&self) -> bool {
use BinaryPredicateOp::*;
match self {
LessThan | GreaterThan | Equal => true,
Distinct | In => false, }
}
}
impl JunctionPredicateOp {
pub(crate) fn invert(&self) -> JunctionPredicateOp {
use JunctionPredicateOp::*;
match self {
And => Or,
Or => And,
}
}
}
impl UnaryExpression {
    /// Builds a unary expression, boxing the converted operand.
    pub(crate) fn new(op: UnaryExpressionOp, expr: impl Into<Expression>) -> Self {
        Self {
            op,
            expr: Box::new(expr.into()),
        }
    }
}
impl UnaryPredicate {
    /// Builds a unary predicate, boxing the converted operand.
    pub(crate) fn new(op: UnaryPredicateOp, expr: impl Into<Expression>) -> Self {
        Self {
            op,
            expr: Box::new(expr.into()),
        }
    }
}
impl BinaryExpression {
    /// Builds a binary expression, converting and boxing `left` before `right`.
    pub(crate) fn new(
        op: BinaryExpressionOp,
        left: impl Into<Expression>,
        right: impl Into<Expression>,
    ) -> Self {
        Self {
            op,
            left: Box::new(left.into()),
            right: Box::new(right.into()),
        }
    }
}
impl BinaryPredicate {
    /// Builds a binary predicate, converting and boxing `left` before `right`.
    pub(crate) fn new(
        op: BinaryPredicateOp,
        left: impl Into<Expression>,
        right: impl Into<Expression>,
    ) -> Self {
        Self {
            op,
            left: Box::new(left.into()),
            right: Box::new(right.into()),
        }
    }
}
impl VariadicExpression {
    /// Builds a variadic expression, converting every operand into an [`Expression`] in order.
    pub(crate) fn new(
        op: VariadicExpressionOp,
        exprs: impl IntoIterator<Item = impl Into<Expression>>,
    ) -> Self {
        let operands: Vec<Expression> = exprs.into_iter().map(Into::into).collect();
        Self { op, exprs: operands }
    }
}
impl ParseJsonExpression {
    /// Builds a JSON-parsing expression over `json_expr` with the given output schema.
    pub(crate) fn new(json_expr: impl Into<Expression>, output_schema: SchemaRef) -> Self {
        let json_expr = Box::new(json_expr.into());
        Self {
            json_expr,
            output_schema,
        }
    }
}
/// An expression converting a map-valued expression into a struct.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MapToStructExpression {
    /// The expression producing the map to convert.
    pub map_expr: Box<Expression>,
}
impl MapToStructExpression {
    /// Builds a map-to-struct expression, boxing the converted map operand.
    pub(crate) fn new(map_expr: impl Into<Expression>) -> Self {
        let map_expr = Box::new(map_expr.into());
        Self { map_expr }
    }
}
impl JunctionPredicate {
pub(crate) fn new(op: JunctionPredicateOp, preds: Vec<Predicate>) -> Self {
Self { op, preds }
}
}
impl Expression {
pub fn references(&self) -> HashSet<&ColumnName> {
let mut references = GetColumnReferences::default();
let _ = references.transform_expr(self);
references.0
}
pub fn column<A>(field_names: impl IntoIterator<Item = A>) -> Expression
where
ColumnName: FromIterator<A>,
{
ColumnName::new(field_names).into()
}
pub fn literal(value: impl Into<Scalar>) -> Self {
Self::Literal(value.into())
}
pub const fn null_literal(data_type: DataType) -> Self {
Self::Literal(Scalar::Null(data_type))
}
pub fn from_pred(value: Predicate) -> Self {
match value {
Predicate::BooleanExpression(expr) => expr,
_ => Self::Predicate(Box::new(value)),
}
}
pub fn struct_from(exprs: impl IntoIterator<Item = impl Into<Arc<Self>>>) -> Self {
Self::Struct(exprs.into_iter().map(Into::into).collect(), None)
}
pub fn struct_with_nullability_from(
exprs: impl IntoIterator<Item = impl Into<Arc<Self>>>,
nullability_predicate: impl Into<Arc<Self>>,
) -> Self {
Self::Struct(
exprs.into_iter().map(Into::into).collect(),
Some(nullability_predicate.into()),
)
}
pub fn transform(transform: Transform) -> Self {
Self::Transform(transform)
}
pub fn is_null(self) -> Predicate {
Predicate::is_null(self)
}
pub fn is_not_null(self) -> Predicate {
Predicate::is_not_null(self)
}
pub fn eq(self, other: impl Into<Self>) -> Predicate {
Predicate::eq(self, other)
}
pub fn ne(self, other: impl Into<Self>) -> Predicate {
Predicate::ne(self, other)
}
pub fn le(self, other: impl Into<Self>) -> Predicate {
Predicate::le(self, other)
}
pub fn lt(self, other: impl Into<Self>) -> Predicate {
Predicate::lt(self, other)
}
pub fn ge(self, other: impl Into<Self>) -> Predicate {
Predicate::ge(self, other)
}
pub fn gt(self, other: impl Into<Self>) -> Predicate {
Predicate::gt(self, other)
}
pub fn distinct(self, other: impl Into<Self>) -> Predicate {
Predicate::distinct(self, other)
}
pub fn unary(op: UnaryExpressionOp, expr: impl Into<Expression>) -> Self {
Self::Unary(UnaryExpression::new(op, expr))
}
pub fn binary(
op: BinaryExpressionOp,
lhs: impl Into<Expression>,
rhs: impl Into<Expression>,
) -> Self {
Self::Binary(BinaryExpression::new(op, lhs, rhs))
}
pub fn variadic(
op: VariadicExpressionOp,
exprs: impl IntoIterator<Item = impl Into<Expression>>,
) -> Self {
Self::Variadic(VariadicExpression::new(op, exprs))
}
pub fn coalesce(exprs: impl IntoIterator<Item = impl Into<Expression>>) -> Self {
Self::variadic(VariadicExpressionOp::Coalesce, exprs)
}
pub fn opaque(
op: impl OpaqueExpressionOp,
exprs: impl IntoIterator<Item = Expression>,
) -> Self {
Self::Opaque(OpaqueExpression::new(Arc::new(op), exprs))
}
pub fn unknown(name: impl Into<String>) -> Self {
Self::Unknown(name.into())
}
pub fn parse_json(json_expr: impl Into<Expression>, output_schema: SchemaRef) -> Self {
Self::ParseJson(ParseJsonExpression::new(json_expr, output_schema))
}
pub fn map_to_struct(map_expr: impl Into<Expression>) -> Self {
Self::MapToStruct(MapToStructExpression::new(map_expr))
}
}
impl Predicate {
pub fn references(&self) -> HashSet<&ColumnName> {
let mut references = GetColumnReferences::default();
let _ = references.transform_pred(self);
references.0
}
pub fn column<A>(field_names: impl IntoIterator<Item = A>) -> Predicate
where
ColumnName: FromIterator<A>,
{
Self::from_expr(ColumnName::new(field_names))
}
pub const fn literal(value: bool) -> Self {
Self::BooleanExpression(Expression::Literal(Scalar::Boolean(value)))
}
pub const fn null_literal() -> Self {
Self::BooleanExpression(Expression::Literal(Scalar::Null(DataType::BOOLEAN)))
}
pub fn from_expr(expr: impl Into<Expression>) -> Self {
match expr.into() {
Expression::Predicate(p) => *p,
expr => Predicate::BooleanExpression(expr),
}
}
pub fn not(pred: impl Into<Self>) -> Self {
Self::Not(Box::new(pred.into()))
}
pub fn is_null(expr: impl Into<Expression>) -> Predicate {
Self::unary(UnaryPredicateOp::IsNull, expr)
}
pub fn is_not_null(expr: impl Into<Expression>) -> Predicate {
Self::not(Self::is_null(expr))
}
pub fn eq(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::Equal, a, b)
}
pub fn ne(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::not(Self::binary(BinaryPredicateOp::Equal, a, b))
}
pub fn le(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::not(Self::binary(BinaryPredicateOp::GreaterThan, a, b))
}
pub fn lt(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::LessThan, a, b)
}
pub fn ge(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::not(Self::binary(BinaryPredicateOp::LessThan, a, b))
}
pub fn gt(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::GreaterThan, a, b)
}
pub fn distinct(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::Distinct, a, b)
}
pub fn and(a: impl Into<Self>, b: impl Into<Self>) -> Self {
Self::and_from([a.into(), b.into()])
}
pub fn or(a: impl Into<Self>, b: impl Into<Self>) -> Self {
Self::or_from([a.into(), b.into()])
}
pub fn and_from(preds: impl IntoIterator<Item = Self>) -> Self {
Self::junction(JunctionPredicateOp::And, preds)
}
pub fn or_from(preds: impl IntoIterator<Item = Self>) -> Self {
Self::junction(JunctionPredicateOp::Or, preds)
}
pub fn unary(op: UnaryPredicateOp, expr: impl Into<Expression>) -> Self {
let expr = Box::new(expr.into());
Self::Unary(UnaryPredicate { op, expr })
}
pub fn binary(
op: BinaryPredicateOp,
lhs: impl Into<Expression>,
rhs: impl Into<Expression>,
) -> Self {
Self::Binary(BinaryPredicate {
op,
left: Box::new(lhs.into()),
right: Box::new(rhs.into()),
})
}
pub fn junction(op: JunctionPredicateOp, preds: impl IntoIterator<Item = Self>) -> Self {
let mut preds: Vec<_> = preds.into_iter().collect();
match preds.len() {
0 => match op {
JunctionPredicateOp::And => Self::literal(true),
JunctionPredicateOp::Or => Self::literal(false),
},
1 => preds.remove(0),
_ => Self::Junction(JunctionPredicate { op, preds }),
}
}
pub fn opaque(op: impl OpaquePredicateOp, exprs: impl IntoIterator<Item = Expression>) -> Self {
Self::Opaque(OpaquePredicate::new(Arc::new(op), exprs))
}
pub fn unknown(name: impl Into<String>) -> Self {
Self::Unknown(name.into())
}
}
impl PartialEq for OpaquePredicate {
    /// Two opaque predicates are equal when their ops compare equal through `dyn_eq` and
    /// their child expressions are equal. The op comparison runs first and short-circuits.
    fn eq(&self, other: &Self) -> bool {
        let same_op = self.op.dyn_eq(other.op.any_ref());
        same_op && self.exprs == other.exprs
    }
}
impl PartialEq for OpaqueExpression {
    /// Two opaque expressions are equal when their ops compare equal through `dyn_eq` and
    /// their child expressions are equal. The op comparison runs first and short-circuits.
    fn eq(&self, other: &Self) -> bool {
        let same_op = self.op.dyn_eq(other.op.any_ref());
        same_op && self.exprs == other.exprs
    }
}
impl Display for UnaryExpressionOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use UnaryExpressionOp::*;
match self {
ToJson => write!(f, "TO_JSON"),
}
}
}
impl Display for BinaryExpressionOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use BinaryExpressionOp::*;
match self {
Plus => write!(f, "+"),
Minus => write!(f, "-"),
Multiply => write!(f, "*"),
Divide => write!(f, "/"),
}
}
}
impl Display for VariadicExpressionOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use VariadicExpressionOp::*;
match self {
Coalesce => write!(f, "COALESCE"),
}
}
}
impl Display for BinaryPredicateOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use BinaryPredicateOp::*;
match self {
LessThan => write!(f, "<"),
GreaterThan => write!(f, ">"),
Equal => write!(f, "="),
Distinct => write!(f, "DISTINCT"),
In => write!(f, "IN"),
}
}
}
/// Renders each child with its `Display` impl and joins the results with `", "`.
///
/// An empty slice yields an empty string.
fn format_child_list<T: Display>(children: &[T]) -> String {
    let rendered: Vec<String> = children.iter().map(ToString::to_string).collect();
    rendered.join(", ")
}
impl Display for Expression {
    /// Renders the expression in a compact, human-readable form.
    ///
    /// Binary operators are printed infix without parentheses, so the output is meant for
    /// logs and debugging rather than re-parsing. For [`Expression::Transform`], per-field
    /// edits are listed in ascending field-name order so the rendering is deterministic.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use Expression::*;
        match self {
            Literal(l) => write!(f, "{l}"),
            Column(name) => write!(f, "Column({name})"),
            Predicate(p) => write!(f, "{p}"),
            // NOTE: the optional nullability predicate is not rendered.
            Struct(exprs, _) => write!(f, "Struct({})", format_child_list(exprs)),
            Transform(transform) => {
                write!(f, "Transform(")?;
                let mut sep = "";
                if !transform.prepended_fields.is_empty() {
                    let prepended_fields = format_child_list(&transform.prepended_fields);
                    write!(f, "prepend [{prepended_fields}]")?;
                    sep = ", ";
                }
                // `field_transforms` is a HashMap, whose iteration order is unspecified and
                // differs between runs. Sort by field name so equal transforms always render
                // identically (logs, snapshots, test assertions). Keys are unique, so an
                // unstable sort is fine.
                let mut field_transforms: Vec<_> = transform.field_transforms.iter().collect();
                field_transforms.sort_unstable_by_key(|(field_name, _)| *field_name);
                for (field_name, field_transform) in field_transforms {
                    let insertions = &field_transform.exprs;
                    if insertions.is_empty() {
                        if !field_transform.is_replace {
                            // Nothing is dropped, replaced, or inserted for this field.
                            continue;
                        }
                        write!(f, "{sep}drop {field_name}")?;
                    } else {
                        let insertions = format_child_list(insertions);
                        if field_transform.is_replace {
                            write!(f, "{sep}replace {field_name} with [{insertions}]")?;
                        } else {
                            write!(f, "{sep}after {field_name} insert [{insertions}]")?;
                        }
                    }
                    sep = ", ";
                }
                write!(f, ")")
            }
            Unary(UnaryExpression { op, expr }) => write!(f, "{op}({expr})"),
            Binary(BinaryExpression { op, left, right }) => write!(f, "{left} {op} {right}"),
            Variadic(VariadicExpression { op, exprs }) => {
                write!(f, "{op}({})", format_child_list(exprs))
            }
            Opaque(OpaqueExpression { op, exprs }) => {
                write!(f, "{op:?}({})", format_child_list(exprs))
            }
            Unknown(name) => write!(f, "<unknown: {name}>"),
            ParseJson(p) => {
                write!(
                    f,
                    "PARSE_JSON({}, <schema:{} fields>)",
                    p.json_expr,
                    p.output_schema.fields().len()
                )
            }
            MapToStruct(m) => write!(f, "MAP_TO_STRUCT({})", m.map_expr),
        }
    }
}
impl Display for Predicate {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use Predicate::*;
match self {
BooleanExpression(expr) => write!(f, "{expr}"),
Not(pred) => write!(f, "NOT({pred})"),
Binary(BinaryPredicate {
op: BinaryPredicateOp::Distinct,
left,
right,
}) => write!(f, "DISTINCT({left}, {right})"),
Binary(BinaryPredicate { op, left, right }) => write!(f, "{left} {op} {right}"),
Unary(UnaryPredicate { op, expr }) => match op {
UnaryPredicateOp::IsNull => write!(f, "{expr} IS NULL"),
},
Junction(JunctionPredicate { op, preds }) => {
let op = match op {
JunctionPredicateOp::And => "AND",
JunctionPredicateOp::Or => "OR",
};
write!(f, "{op}({})", format_child_list(preds))
}
Opaque(OpaquePredicate { op, exprs }) => {
write!(f, "{op:?}({})", format_child_list(exprs))
}
Unknown(name) => write!(f, "<unknown: {name}>"),
}
}
}
impl From<Scalar> for Expression {
fn from(value: Scalar) -> Self {
Self::literal(value)
}
}
impl From<ColumnName> for Expression {
fn from(value: ColumnName) -> Self {
Self::Column(value)
}
}
impl From<Predicate> for Expression {
fn from(value: Predicate) -> Self {
Self::from_pred(value)
}
}
impl From<ColumnName> for Predicate {
fn from(value: ColumnName) -> Self {
Self::from_expr(value)
}
}
impl<R: Into<Expression>> std::ops::Add<R> for Expression {
type Output = Self;
fn add(self, rhs: R) -> Self::Output {
Self::binary(BinaryExpressionOp::Plus, self, rhs)
}
}
impl<R: Into<Expression>> std::ops::Sub<R> for Expression {
type Output = Self;
fn sub(self, rhs: R) -> Self {
Self::binary(BinaryExpressionOp::Minus, self, rhs)
}
}
impl<R: Into<Expression>> std::ops::Mul<R> for Expression {
type Output = Self;
fn mul(self, rhs: R) -> Self {
Self::binary(BinaryExpressionOp::Multiply, self, rhs)
}
}
impl<R: Into<Expression>> std::ops::Div<R> for Expression {
type Output = Self;
fn div(self, rhs: R) -> Self {
Self::binary(BinaryExpressionOp::Divide, self, rhs)
}
}
/// Visitor that records every [`ColumnName`] it encounters, borrowing the names from the tree
/// being walked. Used by [`Expression::references`] and [`Predicate::references`].
#[derive(Default)]
struct GetColumnReferences<'a>(HashSet<&'a ColumnName>);
impl<'a> ExpressionTransform<'a> for GetColumnReferences<'a> {
    /// Records `name` and hands it back unchanged (borrowed). Presumably a borrowed result
    /// tells the transform framework to keep the column as-is; confirm in `crate::transforms`.
    fn transform_expr_column(&mut self, name: &'a ColumnName) -> Option<Cow<'a, ColumnName>> {
        let seen = &mut self.0;
        seen.insert(name);
        Some(Cow::Borrowed(name))
    }
}
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use serde::de::DeserializeOwned;
    use serde::Serialize;

    use super::{column_expr, column_pred, Expression as Expr, Predicate as Pred};

    /// Serializes `value` to JSON, deserializes it back, and asserts the result equals `value`.
    fn assert_roundtrip<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: &T) {
        let json = serde_json::to_string(value).expect("serialization should succeed");
        let deserialized: T = serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(value, &deserialized, "roundtrip should preserve value");
    }

    #[test]
    fn test_expression_format() {
        let cases = [
            (column_expr!("x"), "Column(x)"),
            (
                // Binary operators print infix without parentheses, so grouping is not shown.
                (column_expr!("x") + Expr::literal(4)) / Expr::literal(10) * Expr::literal(42),
                "Column(x) + 4 / 10 * 42",
            ),
            (
                Expr::struct_from([column_expr!("x"), Expr::literal(2), Expr::literal(10)]),
                "Struct(Column(x), 2, 10)",
            ),
        ];
        for (expr, expected) in cases {
            let result = format!("{expr}");
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_predicate_format() {
        let cases = [
            (column_pred!("x"), "Column(x)"),
            (column_expr!("x").eq(Expr::literal(2)), "Column(x) = 2"),
            (
                (column_expr!("x") - Expr::literal(4)).lt(Expr::literal(10)),
                "Column(x) - 4 < 10",
            ),
            (
                Pred::and(
                    column_expr!("x").ge(Expr::literal(2)),
                    column_expr!("x").le(Expr::literal(10)),
                ),
                "AND(NOT(Column(x) < 2), NOT(Column(x) > 10))",
            ),
            (
                Pred::and_from([
                    column_expr!("x").ge(Expr::literal(2)),
                    column_expr!("x").le(Expr::literal(10)),
                    column_expr!("x").le(Expr::literal(100)),
                ]),
                "AND(NOT(Column(x) < 2), NOT(Column(x) > 10), NOT(Column(x) > 100))",
            ),
            (
                Pred::or(
                    column_expr!("x").gt(Expr::literal(2)),
                    column_expr!("x").lt(Expr::literal(10)),
                ),
                "OR(Column(x) > 2, Column(x) < 10)",
            ),
            (
                column_expr!("x").eq(Expr::literal("foo")),
                "Column(x) = 'foo'",
            ),
        ];
        for (pred, expected) in cases {
            let result = format!("{pred}");
            assert_eq!(result, expected);
        }
    }

    /// Serde round-trip tests for every serializable expression and predicate shape.
    mod serde_tests {
        use std::sync::Arc;

        use crate::expressions::scalars::{ArrayData, DecimalData, MapData, StructData};
        use crate::expressions::{
            column_expr, column_name, BinaryExpressionOp, BinaryPredicateOp, ColumnName,
            Expression, Predicate, Scalar, Transform, UnaryExpressionOp,
        };
        use crate::schema::{ArrayType, DataType, DecimalType, MapType, StructField};
        use crate::utils::test_utils::assert_result_error_with_message;

        use super::assert_roundtrip;

        #[test]
        fn test_literal_scalars_roundtrip() {
            let cases: Vec<Expression> = vec![
                Expression::literal(42i32),         // Integer
                Expression::literal(9999999999i64), // Long
                Expression::literal(123i16),        // Short
                Expression::literal(42i8),          // Byte
                // `_f32` / `_f64` are real type suffixes. The previous `_32` / `_64` were only
                // digit separators, so both literals were f64 and Float was never covered.
                Expression::literal(1.12345677_f32), // Float
                Expression::literal(1.12345667_f64), // Double
                Expression::literal("hello world"),
                Expression::literal(true),
                Expression::literal(false),
                Expression::Literal(Scalar::Timestamp(1234567890000000)),
                Expression::Literal(Scalar::TimestampNtz(1234567890000000)),
                Expression::Literal(Scalar::Date(19000)),
                Expression::Literal(Scalar::Binary(vec![1, 2, 3, 4, 5])),
                Expression::Literal(Scalar::Decimal(
                    DecimalData::try_new(12345i128, DecimalType::try_new(10, 2).unwrap()).unwrap(),
                )),
            ];
            for expr in &cases {
                assert_roundtrip(expr);
            }
        }

        #[test]
        fn test_literal_complex_scalars_roundtrip() {
            let cases: Vec<Expression> = vec![
                Expression::null_literal(DataType::INTEGER),
                Expression::null_literal(DataType::STRING),
                Expression::null_literal(DataType::BOOLEAN),
                Expression::Literal(Scalar::Array(
                    ArrayData::try_new(
                        ArrayType::new(DataType::INTEGER, false),
                        vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)],
                    )
                    .unwrap(),
                )),
                Expression::Literal(Scalar::Map(
                    MapData::try_new(
                        MapType::new(DataType::STRING, DataType::INTEGER, false),
                        vec![
                            (Scalar::String("a".to_string()), Scalar::Integer(1)),
                            (Scalar::String("b".to_string()), Scalar::Integer(2)),
                        ],
                    )
                    .unwrap(),
                )),
                Expression::Literal(Scalar::Struct(
                    StructData::try_new(
                        vec![
                            StructField::nullable("x", DataType::INTEGER),
                            StructField::nullable("y", DataType::STRING),
                        ],
                        vec![Scalar::Integer(42), Scalar::String("hello".to_string())],
                    )
                    .unwrap(),
                )),
            ];
            for expr in &cases {
                assert_roundtrip(expr);
            }
        }

        #[test]
        fn test_column_expressions_roundtrip() {
            let cases: Vec<Expression> = vec![
                column_expr!("my_column"),
                Expression::column(["parent", "child"]),
                Expression::column(["a", "b", "c", "d"]),
            ];
            for expr in &cases {
                assert_roundtrip(expr);
            }
        }

        #[test]
        fn test_column_names_roundtrip() {
            let cases: Vec<ColumnName> = vec![
                column_name!("simple"),
                ColumnName::new(["a", "b", "c"]),
                ColumnName::new::<&str>([]),
            ];
            for col in &cases {
                assert_roundtrip(col);
            }
        }

        #[test]
        fn test_unary_expression_roundtrip() {
            let expr = Expression::unary(UnaryExpressionOp::ToJson, column_expr!("data"));
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_binary_expressions_roundtrip() {
            let ops = [
                BinaryExpressionOp::Plus,
                BinaryExpressionOp::Minus,
                BinaryExpressionOp::Multiply,
                BinaryExpressionOp::Divide,
            ];
            for op in ops {
                let expr = Expression::binary(op, column_expr!("a"), Expression::literal(10));
                assert_roundtrip(&expr);
            }
        }

        #[test]
        fn test_variadic_expression_roundtrip() {
            let expr = Expression::coalesce([
                column_expr!("a"),
                column_expr!("b"),
                Expression::literal("default"),
            ]);
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_nested_arithmetic_expression_roundtrip() {
            let left = Expression::binary(
                BinaryExpressionOp::Plus,
                column_expr!("a"),
                column_expr!("b"),
            );
            let right = Expression::binary(
                BinaryExpressionOp::Minus,
                column_expr!("c"),
                column_expr!("d"),
            );
            let mul = Expression::binary(BinaryExpressionOp::Multiply, left, right);
            let expr = Expression::binary(BinaryExpressionOp::Divide, mul, Expression::literal(2));
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_struct_expression_roundtrip() {
            let expr = Expression::struct_from([
                Arc::new(column_expr!("x")),
                Arc::new(Expression::literal(42)),
                Arc::new(Expression::literal("hello")),
            ]);
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_transform_expressions_roundtrip() {
            let cases: Vec<Expression> = vec![
                Expression::transform(Transform::new_top_level()),
                Expression::transform(Transform::new_top_level().with_dropped_field("old_column")),
                Expression::transform(
                    Transform::new_top_level()
                        .with_replaced_field("original", Arc::new(Expression::literal(0))),
                ),
                Expression::transform(
                    Transform::new_top_level()
                        .with_inserted_field(Some("after_col"), Arc::new(column_expr!("new_col")))
                        .with_inserted_field(
                            None::<String>,
                            Arc::new(Expression::literal("prepended")),
                        ),
                ),
                Expression::transform(
                    Transform::new_nested(["parent", "child"]).with_dropped_field("to_drop"),
                ),
            ];
            for expr in &cases {
                assert_roundtrip(expr);
            }
        }

        #[test]
        fn test_expression_wrapping_predicate_roundtrip() {
            let pred = Predicate::eq(column_expr!("x"), Expression::literal(10));
            let expr = Expression::from_pred(pred);
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_expression_unknown_roundtrip() {
            let expr = Expression::unknown("some_unknown_function()");
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_map_to_struct_expression_roundtrip() {
            let cases: Vec<Expression> = vec![
                Expression::map_to_struct(column_expr!("pv")),
                Expression::map_to_struct(Expression::literal("ignored")),
            ];
            for expr in &cases {
                assert_roundtrip(expr);
            }
        }

        #[test]
        fn test_predicate_basics_roundtrip() {
            let cases: Vec<Predicate> = vec![
                Predicate::from_expr(column_expr!("is_active")),
                Predicate::literal(true),
                Predicate::literal(false),
                Predicate::not(Predicate::from_expr(column_expr!("x"))),
                Predicate::not(Predicate::not(Predicate::gt(
                    column_expr!("x"),
                    Expression::literal(5),
                ))),
                Predicate::unknown("some_unknown_predicate()"),
                Predicate::is_null(column_expr!("nullable_col")),
                Predicate::is_not_null(column_expr!("nullable_col")),
            ];
            for pred in &cases {
                assert_roundtrip(pred);
            }
        }

        #[test]
        fn test_predicate_null_literal_roundtrip() {
            let pred = Predicate::null_literal();
            assert_roundtrip(&pred);
        }

        #[test]
        fn test_predicate_comparisons_roundtrip() {
            let cases: Vec<Predicate> = vec![
                Predicate::eq(column_expr!("x"), Expression::literal(42)),
                Predicate::ne(column_expr!("status"), Expression::literal("active")),
                Predicate::lt(column_expr!("age"), Expression::literal(18)),
                Predicate::le(column_expr!("price"), Expression::literal(100)),
                Predicate::gt(column_expr!("score"), Expression::literal(90)),
                Predicate::ge(column_expr!("quantity"), Expression::literal(1)),
                Predicate::distinct(column_expr!("a"), column_expr!("b")),
            ];
            for pred in &cases {
                assert_roundtrip(pred);
            }
        }

        #[test]
        fn test_predicate_in_roundtrip() {
            let array_data = ArrayData::try_new(
                ArrayType::new(DataType::INTEGER, false),
                vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)],
            )
            .unwrap();
            let pred = Predicate::binary(
                BinaryPredicateOp::In,
                column_expr!("x"),
                Expression::Literal(Scalar::Array(array_data)),
            );
            assert_roundtrip(&pred);
        }

        #[test]
        fn test_predicate_junctions_roundtrip() {
            let cases: Vec<Predicate> = vec![
                Predicate::and(
                    Predicate::gt(column_expr!("x"), Expression::literal(0)),
                    Predicate::lt(column_expr!("x"), Expression::literal(100)),
                ),
                Predicate::or(
                    Predicate::eq(column_expr!("status"), Expression::literal("active")),
                    Predicate::eq(column_expr!("status"), Expression::literal("pending")),
                ),
                Predicate::and_from([
                    Predicate::gt(column_expr!("x"), Expression::literal(0)),
                    Predicate::lt(column_expr!("x"), Expression::literal(100)),
                    Predicate::is_not_null(column_expr!("x")),
                ]),
                Predicate::or_from([
                    Predicate::eq(column_expr!("type"), Expression::literal("A")),
                    Predicate::eq(column_expr!("type"), Expression::literal("B")),
                    Predicate::eq(column_expr!("type"), Expression::literal("C")),
                ]),
                Predicate::or(
                    Predicate::and(
                        Predicate::gt(column_expr!("a"), Expression::literal(0)),
                        Predicate::lt(column_expr!("b"), Expression::literal(100)),
                    ),
                    Predicate::eq(column_expr!("c"), Expression::literal("special")),
                ),
            ];
            for pred in &cases {
                assert_roundtrip(pred);
            }
        }

        #[test]
        fn test_deeply_nested_structures_roundtrip() {
            let add = Expression::binary(
                BinaryExpressionOp::Plus,
                column_expr!("a"),
                column_expr!("b"),
            );
            let mul = Expression::binary(
                BinaryExpressionOp::Multiply,
                column_expr!("c"),
                column_expr!("d"),
            );
            let coalesce = Expression::coalesce([add, mul, Expression::literal(0)]);
            let pred = Predicate::gt(coalesce, Expression::literal(100));
            assert_roundtrip(&pred);
            let inner_pred = Predicate::and(
                Predicate::eq(column_expr!("x"), Expression::literal(1)),
                Predicate::gt(
                    Expression::binary(
                        BinaryExpressionOp::Plus,
                        column_expr!("y"),
                        column_expr!("z"),
                    ),
                    Expression::literal(10),
                ),
            );
            let expr = Expression::from_pred(inner_pred);
            assert_roundtrip(&expr);
        }

        #[test]
        fn test_opaque_expression_serialize_fails() {
            use crate::expressions::{OpaqueExpressionOp, ScalarExpressionEvaluator};
            use crate::DeltaResult;

            #[derive(Debug, PartialEq)]
            struct TestOpaqueExprOp;
            impl OpaqueExpressionOp for TestOpaqueExprOp {
                fn name(&self) -> &str {
                    "test_opaque"
                }
                fn eval_expr_scalar(
                    &self,
                    _eval_expr: &ScalarExpressionEvaluator<'_>,
                    _exprs: &[Expression],
                ) -> DeltaResult<Scalar> {
                    Ok(Scalar::Integer(0))
                }
            }

            let expr = Expression::opaque(TestOpaqueExprOp, [Expression::literal(1)]);
            let result = serde_json::to_string(&expr);
            assert_result_error_with_message(result, "Cannot serialize an Opaque Expression");
        }

        #[test]
        fn test_opaque_predicate_serialize_fails() {
            use crate::expressions::{OpaquePredicateOp, ScalarExpressionEvaluator};
            use crate::kernel_predicates::{
                DirectDataSkippingPredicateEvaluator, DirectPredicateEvaluator,
                IndirectDataSkippingPredicateEvaluator,
            };
            use crate::DeltaResult;

            #[derive(Debug, PartialEq)]
            struct TestOpaquePredOp;
            impl OpaquePredicateOp for TestOpaquePredOp {
                fn name(&self) -> &str {
                    "test_opaque_pred"
                }
                fn eval_pred_scalar(
                    &self,
                    _eval_expr: &ScalarExpressionEvaluator<'_>,
                    _eval_pred: &DirectPredicateEvaluator<'_>,
                    _exprs: &[Expression],
                    _inverted: bool,
                ) -> DeltaResult<Option<bool>> {
                    Ok(Some(true))
                }
                fn eval_as_data_skipping_predicate(
                    &self,
                    _evaluator: &DirectDataSkippingPredicateEvaluator<'_>,
                    _exprs: &[Expression],
                    _inverted: bool,
                ) -> Option<bool> {
                    Some(true)
                }
                fn as_data_skipping_predicate(
                    &self,
                    _evaluator: &IndirectDataSkippingPredicateEvaluator<'_>,
                    _exprs: &[Expression],
                    _inverted: bool,
                ) -> Option<Predicate> {
                    None
                }
            }

            let pred = Predicate::opaque(TestOpaquePredOp, [Expression::literal(1)]);
            let result = serde_json::to_string(&pred);
            assert_result_error_with_message(result, "Cannot serialize an Opaque Predicate");
        }
    }

    #[test]
    fn single_element_and_from_returns_unwrapped_predicate() {
        let inner = Pred::gt(column_expr!("x"), Expr::literal(0));
        let result = Pred::and_from([inner.clone()]);
        assert_eq!(result, inner);
    }

    #[test]
    fn single_element_or_from_returns_unwrapped_predicate() {
        let inner = Pred::gt(column_expr!("x"), Expr::literal(0));
        let result = Pred::or_from([inner.clone()]);
        assert_eq!(result, inner);
    }

    #[test]
    fn multi_element_and_from_returns_junction() {
        let p1 = Pred::gt(column_expr!("x"), Expr::literal(0));
        let p2 = Pred::lt(column_expr!("x"), Expr::literal(100));
        let result = Pred::and_from([p1.clone(), p2.clone()]);
        assert!(matches!(result, Pred::Junction(ref j) if j.preds.len() == 2));
        assert_eq!(result, Pred::and(p1, p2));
    }

    #[test]
    fn empty_and_from_returns_identity_literal() {
        let result = Pred::and_from(std::iter::empty());
        assert_eq!(result, Pred::literal(true));
    }

    #[test]
    fn empty_or_from_returns_identity_literal() {
        let result = Pred::or_from(std::iter::empty());
        assert_eq!(result, Pred::literal(false));
    }
}