use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use itertools::Itertools;
use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
pub use self::column_names::{
col, column_expr, column_expr_ref, column_name, column_pred, joined_column_expr,
joined_column_name, ColumnName,
};
pub use self::scalars::{ArrayData, DecimalData, MapData, Scalar, StructData};
use crate::kernel_predicates::{
DirectDataSkippingPredicateEvaluator, DirectPredicateEvaluator,
IndirectDataSkippingPredicateEvaluator,
};
use crate::schema::SchemaRef;
pub use crate::struct_patch::{ExpressionFieldPatch, ExpressionStructPatch};
use crate::transforms::{transform_output_type, ExpressionTransform};
use crate::utils::CollectInto;
use crate::{DataType, DeltaResult, DynPartialEq, Error};
mod column_names;
pub(crate) mod literal_expression_transform;
pub(crate) use literal_expression_transform::literal_expression_transform;
mod scalars;
#[cfg(feature = "column-defaults-in-dev")]
mod sql;
#[cfg(feature = "column-defaults-in-dev")]
#[allow(unused_imports)]
pub(crate) use self::sql::parse_sql;
/// Shared, reference-counted handle to an [`Expression`].
pub type ExpressionRef = std::sync::Arc<Expression>;
/// Shared, reference-counted handle to a [`Predicate`].
pub type PredicateRef = std::sync::Arc<Predicate>;
/// Shorthand for building a literal [`Expression`] from any value convertible into a [`Scalar`].
pub fn lit(value: impl Into<Scalar>) -> Expression {
    let scalar: Scalar = value.into();
    Expression::Literal(scalar)
}
/// The crate's `StructPatchBuilder` specialized to [`ExpressionRef`] payloads.
pub type ExpressionStructPatchBuilder = crate::struct_patch::StructPatchBuilder<ExpressionRef>;
/// Operators of unary predicates (a single expression operand).
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, matching [`BinaryPredicateOp`] and the
/// expression operator enums, so the operator can be used as a hash key or in `Eq`-bounded
/// contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryPredicateOp {
    /// Rendered as `<expr> IS NULL` by the `Display` impl of [`Predicate`].
    IsNull,
}
/// Operators of binary predicates (two expression operands).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinaryPredicateOp {
/// Rendered as `<` by the `Display` impl.
LessThan,
/// Rendered as `>` by the `Display` impl.
GreaterThan,
/// Rendered as `=` by the `Display` impl.
Equal,
/// Rendered as `DISTINCT`; `is_null_intolerant` returns `false` for it.
Distinct,
/// Rendered as `IN`; `is_null_intolerant` returns `false` for it.
In,
}
/// Operators of unary (single-operand) expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UnaryExpressionOp {
/// Rendered as `TO_JSON` by the `Display` impl.
/// NOTE(review): presumably converts the operand to a JSON string — confirm with the evaluator.
ToJson,
}
/// Arithmetic operators of binary (two-operand) expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BinaryExpressionOp {
/// Rendered as `+`; produced by the `Add` operator impl on [`Expression`].
Plus,
/// Rendered as `-`; produced by the `Sub` operator impl on [`Expression`].
Minus,
/// Rendered as `*`; produced by the `Mul` operator impl on [`Expression`].
Multiply,
/// Rendered as `/`; produced by the `Div` operator impl on [`Expression`].
Divide,
}
/// Operators of expressions that take an arbitrary number of operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VariadicExpressionOp {
/// Rendered as `COALESCE`; built by [`Expression::coalesce`].
Coalesce,
/// Rendered as `ARRAY`; built by [`Expression::array`].
Array,
}
/// Operators that combine any number of predicates into one.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, matching [`BinaryPredicateOp`] and the
/// expression operator enums, so the operator can be used as a hash key or in `Eq`-bounded
/// contexts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum JunctionPredicateOp {
    /// Conjunction; rendered as `AND(...)` by the `Display` impl of [`Predicate`].
    And,
    /// Disjunction; rendered as `OR(...)` by the `Display` impl of [`Predicate`].
    Or,
}
/// Callback type that maps an [`Expression`] to an optional [`Scalar`].
/// NOTE(review): the meaning of a `None` result is not visible here — confirm with the callers.
pub type ScalarExpressionEvaluator<'a> = dyn Fn(&Expression) -> Option<Scalar> + 'a;
/// Behavior of an opaque expression operator, held as a trait object
/// (see [`OpaqueExpressionOpRef`]) inside an [`OpaqueExpression`].
pub trait OpaqueExpressionOp: DynPartialEq + std::fmt::Debug {
/// Returns the operator's name.
fn name(&self) -> &str;
/// Evaluates this operator over `exprs` and returns the resulting scalar or an error.
///
/// `eval_expr` is a callback that maps an expression to an optional scalar.
/// NOTE(review): presumably meant for evaluating the child expressions — confirm.
fn eval_expr_scalar(
&self,
eval_expr: &ScalarExpressionEvaluator<'_>,
exprs: &[Expression],
) -> DeltaResult<Scalar>;
}
/// Behavior of an opaque predicate operator, held as a trait object
/// (see [`OpaquePredicateOpRef`]) inside an [`OpaquePredicate`].
///
/// NOTE(review): every evaluation method takes an `inverted` flag; presumably it is `true`
/// when the predicate is being evaluated under a negation — confirm with the evaluators.
pub trait OpaquePredicateOp: DynPartialEq + std::fmt::Debug {
/// Returns the operator's name.
fn name(&self) -> &str;
/// Evaluates this predicate over `exprs`, returning an optional boolean or an error.
///
/// `eval_expr` and `eval_pred` are the evaluators made available to the implementation.
fn eval_pred_scalar(
&self,
eval_expr: &ScalarExpressionEvaluator<'_>,
eval_pred: &DirectPredicateEvaluator<'_>,
exprs: &[Expression],
inverted: bool,
) -> DeltaResult<Option<bool>>;
/// Evaluates this predicate through a direct data-skipping evaluator, returning an
/// optional boolean.
fn eval_as_data_skipping_predicate(
&self,
evaluator: &DirectDataSkippingPredicateEvaluator<'_>,
exprs: &[Expression],
inverted: bool,
) -> Option<bool>;
/// Produces, through an indirect data-skipping evaluator, an optional [`Predicate`]
/// derived from this one.
fn as_data_skipping_predicate(
&self,
evaluator: &IndirectDataSkippingPredicateEvaluator<'_>,
exprs: &[Expression],
inverted: bool,
) -> Option<Predicate>;
}
/// Shared, reference-counted trait object for an [`OpaqueExpressionOp`].
pub type OpaqueExpressionOpRef = Arc<dyn OpaqueExpressionOp>;
/// Shared, reference-counted trait object for an [`OpaquePredicateOp`].
pub type OpaquePredicateOpRef = Arc<dyn OpaquePredicateOp>;
/// A predicate node applying a [`UnaryPredicateOp`] to one expression.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UnaryPredicate {
/// The operator to apply.
pub op: UnaryPredicateOp,
/// The single operand.
pub expr: Box<Expression>,
}
/// A predicate node applying a [`BinaryPredicateOp`] to two expressions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BinaryPredicate {
/// The operator to apply.
pub op: BinaryPredicateOp,
/// The left-hand operand.
pub left: Box<Expression>,
/// The right-hand operand.
pub right: Box<Expression>,
}
/// An expression node applying a [`UnaryExpressionOp`] to one expression.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UnaryExpression {
/// The operator to apply.
pub op: UnaryExpressionOp,
/// The single operand.
pub expr: Box<Expression>,
}
/// An expression node applying a [`BinaryExpressionOp`] to two expressions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BinaryExpression {
/// The operator to apply.
pub op: BinaryExpressionOp,
/// The left-hand operand.
pub left: Box<Expression>,
/// The right-hand operand.
pub right: Box<Expression>,
}
/// An expression node applying a [`VariadicExpressionOp`] to a list of expressions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct VariadicExpression {
/// The operator to apply.
pub op: VariadicExpressionOp,
/// The operands, in order.
pub exprs: Vec<Expression>,
}
/// An expression node pairing an input expression with an output schema; rendered as
/// `PARSE_JSON(...)` by the `Display` impl of [`Expression`].
/// NOTE(review): presumably parses the JSON text produced by `json_expr` into a value of
/// `output_schema` — confirm with the evaluator.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ParseJsonExpression {
/// The input expression.
pub json_expr: Box<Expression>,
/// The schema associated with the output.
pub output_schema: SchemaRef,
}
/// A predicate node combining a list of predicates with a [`JunctionPredicateOp`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct JunctionPredicate {
/// The combining operator (`And` or `Or`).
pub op: JunctionPredicateOp,
/// The combined predicates, in order.
pub preds: Vec<Predicate>,
}
/// A predicate node whose operator is an [`OpaquePredicateOp`] trait object.
///
/// It derives neither `Serialize` nor `Deserialize`; the `Predicate::Opaque` variant is wired
/// to hooks that always fail. Equality is implemented by hand via `dyn_eq` on the operator.
#[derive(Clone, Debug)]
pub struct OpaquePredicate {
/// The opaque operator.
pub op: OpaquePredicateOpRef,
/// The operand expressions handed to the operator.
pub exprs: Vec<Expression>,
}
fn fail_serialize_opaque_predicate<S>(
_value: &OpaquePredicate,
_serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
Err(ser::Error::custom("Cannot serialize an Opaque Predicate"))
}
fn fail_deserialize_opaque_predicate<'de, D>(_deserializer: D) -> Result<OpaquePredicate, D::Error>
where
D: Deserializer<'de>,
{
Err(de::Error::custom("Cannot deserialize an Opaque Predicate"))
}
impl OpaquePredicate {
    /// Builds an opaque predicate from its operator and the operand expressions, which are
    /// collected into a `Vec` in iteration order.
    pub(crate) fn new(
        op: OpaquePredicateOpRef,
        exprs: impl IntoIterator<Item = Expression>,
    ) -> Self {
        Self {
            op,
            exprs: exprs.into_iter().collect::<Vec<_>>(),
        }
    }
}
/// An expression node whose operator is an [`OpaqueExpressionOp`] trait object.
///
/// It derives neither `Serialize` nor `Deserialize`; the `Expression::Opaque` variant is wired
/// to hooks that always fail. Equality is implemented by hand via `dyn_eq` on the operator.
#[derive(Clone, Debug)]
pub struct OpaqueExpression {
/// The opaque operator.
pub op: OpaqueExpressionOpRef,
/// The operand expressions handed to the operator.
pub exprs: Vec<Expression>,
}
impl OpaqueExpression {
    /// Builds an opaque expression from its operator and the operand expressions, which are
    /// collected into a `Vec` in iteration order.
    pub(crate) fn new(
        op: OpaqueExpressionOpRef,
        exprs: impl IntoIterator<Item = Expression>,
    ) -> Self {
        Self {
            op,
            exprs: exprs.into_iter().collect::<Vec<_>>(),
        }
    }
}
fn fail_serialize_opaque_expression<S>(
_value: &OpaqueExpression,
_serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
Err(ser::Error::custom("Cannot serialize an Opaque Expression"))
}
fn fail_deserialize_opaque_expression<'de, D>(
_deserializer: D,
) -> Result<OpaqueExpression, D::Error>
where
D: Deserializer<'de>,
{
Err(de::Error::custom("Cannot deserialize an Opaque Expression"))
}
/// A scalar-valued expression tree.
///
/// Every variant is serde-serializable except [`Expression::Opaque`], whose serialization and
/// deserialization hooks always return an error.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expression {
    /// A constant scalar value.
    Literal(Scalar),
    /// A reference to a column by (possibly multi-part) name.
    Column(ColumnName),
    /// A predicate used where an expression is expected.
    Predicate(Box<Predicate>),
    /// A struct assembled from field expressions, plus an optional extra expression that
    /// [`Expression::struct_with_nullability_from`] fills in as the nullability predicate.
    Struct(Vec<ExpressionRef>, Option<ExpressionRef>),
    /// A struct patch; also accepted under the name `"Transform"` when deserializing.
    #[serde(alias = "Transform")]
    StructPatch(ExpressionStructPatch),
    /// A single-operand expression.
    Unary(UnaryExpression),
    /// A two-operand (arithmetic) expression.
    Binary(BinaryExpression),
    /// An expression over an arbitrary number of operands.
    Variadic(VariadicExpression),
    /// An expression with an opaque operator; cannot be serialized or deserialized.
    #[serde(serialize_with = "fail_serialize_opaque_expression")]
    #[serde(deserialize_with = "fail_deserialize_opaque_expression")]
    Opaque(OpaqueExpression),
    /// A placeholder carrying only a name; rendered as `<unknown: NAME>`.
    Unknown(String),
    /// A `PARSE_JSON` node (input expression plus output schema).
    ParseJson(ParseJsonExpression),
    /// A `MAP_TO_STRUCT` node wrapping a single input expression.
    MapToStruct(MapToStructExpression),
}
/// A boolean-valued predicate tree.
///
/// Every variant is serde-serializable except `Opaque`, whose serialization and
/// deserialization hooks always return an error.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Predicate {
/// An arbitrary expression used as a predicate.
BooleanExpression(Expression),
/// Negation of the inner predicate; rendered as `NOT(...)`.
Not(Box<Predicate>),
/// A single-operand predicate.
Unary(UnaryPredicate),
/// A two-operand predicate.
Binary(BinaryPredicate),
/// An AND/OR combination of predicates.
Junction(JunctionPredicate),
/// A predicate with an opaque operator; cannot be serialized or deserialized.
#[serde(serialize_with = "fail_serialize_opaque_predicate")]
#[serde(deserialize_with = "fail_deserialize_opaque_predicate")]
Opaque(OpaquePredicate),
/// A placeholder carrying only a name; rendered as `<unknown: NAME>`.
Unknown(String),
}
impl BinaryPredicateOp {
pub(crate) fn is_null_intolerant(&self) -> bool {
use BinaryPredicateOp::*;
match self {
LessThan | GreaterThan | Equal => true,
Distinct | In => false, }
}
}
impl JunctionPredicateOp {
pub(crate) fn invert(&self) -> JunctionPredicateOp {
use JunctionPredicateOp::*;
match self {
And => Or,
Or => And,
}
}
}
impl UnaryExpression {
    /// Builds a unary expression node, converting and boxing the operand.
    pub(crate) fn new(op: UnaryExpressionOp, expr: impl Into<Expression>) -> Self {
        Self {
            op,
            expr: Box::new(expr.into()),
        }
    }
}
impl UnaryPredicate {
    /// Builds a unary predicate node, converting and boxing the operand.
    pub(crate) fn new(op: UnaryPredicateOp, expr: impl Into<Expression>) -> Self {
        Self {
            op,
            expr: Box::new(expr.into()),
        }
    }
}
impl BinaryExpression {
    /// Builds a binary expression node, converting and boxing both operands
    /// (left first, then right).
    pub(crate) fn new(
        op: BinaryExpressionOp,
        left: impl Into<Expression>,
        right: impl Into<Expression>,
    ) -> Self {
        Self {
            op,
            left: Box::new(left.into()),
            right: Box::new(right.into()),
        }
    }
}
impl BinaryPredicate {
    /// Builds a binary predicate node, converting and boxing both operands
    /// (left first, then right).
    pub(crate) fn new(
        op: BinaryPredicateOp,
        left: impl Into<Expression>,
        right: impl Into<Expression>,
    ) -> Self {
        Self {
            op,
            left: Box::new(left.into()),
            right: Box::new(right.into()),
        }
    }
}
impl VariadicExpression {
    /// Builds a variadic expression node, converting each operand into an [`Expression`]
    /// and keeping iteration order.
    pub(crate) fn new(
        op: VariadicExpressionOp,
        exprs: impl IntoIterator<Item = impl Into<Expression>>,
    ) -> Self {
        let operands: Vec<Expression> = exprs.into_iter().map(|operand| operand.into()).collect();
        Self {
            op,
            exprs: operands,
        }
    }
}
impl ParseJsonExpression {
    /// Builds a parse-JSON node from the input expression (converted and boxed) and the
    /// output schema.
    pub(crate) fn new(json_expr: impl Into<Expression>, output_schema: SchemaRef) -> Self {
        let json_expr = Box::new(json_expr.into());
        Self {
            json_expr,
            output_schema,
        }
    }
}
/// An expression node wrapping a single input expression; rendered as `MAP_TO_STRUCT(...)`
/// by the `Display` impl of [`Expression`].
/// NOTE(review): presumably converts a map-typed value into a struct — confirm with the evaluator.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct MapToStructExpression {
/// The input expression.
pub map_expr: Box<Expression>,
}
impl MapToStructExpression {
    /// Builds a map-to-struct node, converting and boxing the input expression.
    pub(crate) fn new(map_expr: impl Into<Expression>) -> Self {
        let map_expr = Box::new(map_expr.into());
        Self { map_expr }
    }
}
impl JunctionPredicate {
pub(crate) fn new(op: JunctionPredicateOp, preds: Vec<Predicate>) -> Self {
Self { op, preds }
}
}
impl Expression {
pub fn references(&self) -> HashSet<&ColumnName> {
let mut references = GetColumnReferences::default();
references.transform_expr(self);
references.0
}
pub fn column(field_names: impl CollectInto<ColumnName>) -> Expression {
ColumnName::new(field_names).into()
}
pub fn literal(value: impl Into<Scalar>) -> Self {
Self::Literal(value.into())
}
pub const fn null_literal(data_type: DataType) -> Self {
Self::Literal(Scalar::Null(data_type))
}
pub fn from_pred(value: Predicate) -> Self {
match value {
Predicate::BooleanExpression(expr) => expr,
_ => Self::Predicate(Box::new(value)),
}
}
pub fn struct_from(exprs: impl IntoIterator<Item = impl Into<Arc<Self>>>) -> Self {
Self::Struct(exprs.into_iter().map(Into::into).collect(), None)
}
pub fn struct_with_nullability_from(
exprs: impl IntoIterator<Item = impl Into<Arc<Self>>>,
nullability_predicate: impl Into<Arc<Self>>,
) -> Self {
Self::Struct(
exprs.into_iter().map(Into::into).collect(),
Some(nullability_predicate.into()),
)
}
pub fn struct_patch<P>(patch: P) -> DeltaResult<Self>
where
P: TryInto<ExpressionStructPatch>,
Error: From<P::Error>,
{
Ok(Self::StructPatch(patch.try_into()?))
}
pub fn is_null(self) -> Predicate {
Predicate::is_null(self)
}
pub fn is_not_null(self) -> Predicate {
Predicate::is_not_null(self)
}
pub fn eq(self, other: impl Into<Self>) -> Predicate {
Predicate::eq(self, other)
}
pub fn ne(self, other: impl Into<Self>) -> Predicate {
Predicate::ne(self, other)
}
pub fn le(self, other: impl Into<Self>) -> Predicate {
Predicate::le(self, other)
}
pub fn lt(self, other: impl Into<Self>) -> Predicate {
Predicate::lt(self, other)
}
pub fn ge(self, other: impl Into<Self>) -> Predicate {
Predicate::ge(self, other)
}
pub fn gt(self, other: impl Into<Self>) -> Predicate {
Predicate::gt(self, other)
}
pub fn distinct(self, other: impl Into<Self>) -> Predicate {
Predicate::distinct(self, other)
}
pub fn unary(op: UnaryExpressionOp, expr: impl Into<Expression>) -> Self {
Self::Unary(UnaryExpression::new(op, expr))
}
pub fn binary(
op: BinaryExpressionOp,
lhs: impl Into<Expression>,
rhs: impl Into<Expression>,
) -> Self {
Self::Binary(BinaryExpression::new(op, lhs, rhs))
}
pub fn variadic(
op: VariadicExpressionOp,
exprs: impl IntoIterator<Item = impl Into<Expression>>,
) -> Self {
Self::Variadic(VariadicExpression::new(op, exprs))
}
pub fn coalesce(exprs: impl IntoIterator<Item = impl Into<Expression>>) -> Self {
Self::variadic(VariadicExpressionOp::Coalesce, exprs)
}
pub fn array(exprs: impl IntoIterator<Item = impl Into<Expression>>) -> Self {
Self::variadic(VariadicExpressionOp::Array, exprs)
}
pub fn opaque(
op: impl OpaqueExpressionOp,
exprs: impl IntoIterator<Item = Expression>,
) -> Self {
Self::Opaque(OpaqueExpression::new(Arc::new(op), exprs))
}
pub fn unknown(name: impl Into<String>) -> Self {
Self::Unknown(name.into())
}
pub fn parse_json(json_expr: impl Into<Expression>, output_schema: SchemaRef) -> Self {
Self::ParseJson(ParseJsonExpression::new(json_expr, output_schema))
}
pub fn map_to_struct(map_expr: impl Into<Expression>) -> Self {
Self::MapToStruct(MapToStructExpression::new(map_expr))
}
}
impl Predicate {
pub fn references(&self) -> HashSet<&ColumnName> {
let mut references = GetColumnReferences::default();
references.transform_pred(self);
references.0
}
pub fn column(field_names: impl CollectInto<ColumnName>) -> Predicate {
Self::from_expr(ColumnName::new(field_names))
}
pub const fn literal(value: bool) -> Self {
Self::BooleanExpression(Expression::Literal(Scalar::Boolean(value)))
}
pub const fn null_literal() -> Self {
Self::BooleanExpression(Expression::Literal(Scalar::Null(DataType::BOOLEAN)))
}
pub fn from_expr(expr: impl Into<Expression>) -> Self {
match expr.into() {
Expression::Predicate(p) => *p,
expr => Predicate::BooleanExpression(expr),
}
}
pub fn not(pred: impl Into<Self>) -> Self {
Self::Not(Box::new(pred.into()))
}
pub fn is_null(expr: impl Into<Expression>) -> Predicate {
Self::unary(UnaryPredicateOp::IsNull, expr)
}
pub fn is_not_null(expr: impl Into<Expression>) -> Predicate {
Self::not(Self::is_null(expr))
}
pub fn eq(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::Equal, a, b)
}
pub fn ne(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::not(Self::binary(BinaryPredicateOp::Equal, a, b))
}
pub fn le(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::not(Self::binary(BinaryPredicateOp::GreaterThan, a, b))
}
pub fn lt(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::LessThan, a, b)
}
pub fn ge(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::not(Self::binary(BinaryPredicateOp::LessThan, a, b))
}
pub fn gt(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::GreaterThan, a, b)
}
pub fn distinct(a: impl Into<Expression>, b: impl Into<Expression>) -> Self {
Self::binary(BinaryPredicateOp::Distinct, a, b)
}
pub fn and(a: impl Into<Self>, b: impl Into<Self>) -> Self {
Self::and_from([a.into(), b.into()])
}
pub fn or(a: impl Into<Self>, b: impl Into<Self>) -> Self {
Self::or_from([a.into(), b.into()])
}
pub fn and_from(preds: impl IntoIterator<Item = Self>) -> Self {
Self::junction(JunctionPredicateOp::And, preds)
}
pub fn or_from(preds: impl IntoIterator<Item = Self>) -> Self {
Self::junction(JunctionPredicateOp::Or, preds)
}
pub fn unary(op: UnaryPredicateOp, expr: impl Into<Expression>) -> Self {
let expr = Box::new(expr.into());
Self::Unary(UnaryPredicate { op, expr })
}
pub fn binary(
op: BinaryPredicateOp,
lhs: impl Into<Expression>,
rhs: impl Into<Expression>,
) -> Self {
Self::Binary(BinaryPredicate {
op,
left: Box::new(lhs.into()),
right: Box::new(rhs.into()),
})
}
pub fn junction(op: JunctionPredicateOp, preds: impl IntoIterator<Item = Self>) -> Self {
let mut preds: Vec<_> = preds.into_iter().collect();
match preds.len() {
0 => match op {
JunctionPredicateOp::And => Self::literal(true),
JunctionPredicateOp::Or => Self::literal(false),
},
1 => preds.remove(0),
_ => Self::Junction(JunctionPredicate { op, preds }),
}
}
pub fn opaque(op: impl OpaquePredicateOp, exprs: impl IntoIterator<Item = Expression>) -> Self {
Self::Opaque(OpaquePredicate::new(Arc::new(op), exprs))
}
pub fn unknown(name: impl Into<String>) -> Self {
Self::Unknown(name.into())
}
}
/// Two opaque predicates are equal when their operators compare equal through `dyn_eq`
/// and their operand lists are equal.
impl PartialEq for OpaquePredicate {
    fn eq(&self, other: &Self) -> bool {
        let Self { op, exprs } = self;
        // Operator check first; operands are only compared when the operators match.
        op.dyn_eq(other.op.any_ref()) && *exprs == other.exprs
    }
}
/// Two opaque expressions are equal when their operators compare equal through `dyn_eq`
/// and their operand lists are equal.
impl PartialEq for OpaqueExpression {
    fn eq(&self, other: &Self) -> bool {
        let Self { op, exprs } = self;
        // Operator check first; operands are only compared when the operators match.
        op.dyn_eq(other.op.any_ref()) && *exprs == other.exprs
    }
}
impl Display for UnaryExpressionOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use UnaryExpressionOp::*;
match self {
ToJson => write!(f, "TO_JSON"),
}
}
}
impl Display for BinaryExpressionOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use BinaryExpressionOp::*;
match self {
Plus => write!(f, "+"),
Minus => write!(f, "-"),
Multiply => write!(f, "*"),
Divide => write!(f, "/"),
}
}
}
impl Display for VariadicExpressionOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use VariadicExpressionOp::*;
match self {
Coalesce => write!(f, "COALESCE"),
Array => write!(f, "ARRAY"),
}
}
}
impl Display for BinaryPredicateOp {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use BinaryPredicateOp::*;
match self {
LessThan => write!(f, "<"),
GreaterThan => write!(f, ">"),
Equal => write!(f, "="),
Distinct => write!(f, "DISTINCT"),
In => write!(f, "IN"),
}
}
}
fn format_child_list<T: Display>(children: &[T]) -> String {
children.iter().map(|c| format!("{c}")).join(", ")
}
// Human-readable rendering of an expression tree.
// Binary expressions are written infix without parentheses, so operand grouping is not
// visible in the output.
impl Display for Expression {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use Expression::*;
match self {
Literal(l) => write!(f, "{l}"),
Column(name) => write!(f, "Column({name})"),
Predicate(p) => write!(f, "{p}"),
// Only the field expressions are printed; the optional second component is ignored.
Struct(exprs, _) => write!(f, "Struct({})", format_child_list(exprs)),
StructPatch(patch) => {
write!(f, "StructPatch(")?;
// `sep` stays empty until the first section is written, then becomes ", " so that
// every later section is comma-separated from the previous one.
let mut sep = "";
if !patch.prepended_fields.is_empty() {
let prepended_fields = format_child_list(&patch.prepended_fields);
write!(f, "prepend [{prepended_fields}]")?;
sep = ", ";
}
for (field_name, field_patch) in &patch.field_patches {
// Input not kept and nothing inserted: the field is reported as dropped.
if !field_patch.keep_input && field_patch.insertions.is_empty() {
write!(f, "{sep}drop {field_name}")?;
sep = ", ";
}
// With insertions, the label depends on whether the input field is kept.
if !field_patch.insertions.is_empty() {
let insertions = format_child_list(&field_patch.insertions);
let action = if field_patch.keep_input {
"after"
} else {
"replace/after"
};
write!(f, "{sep}{action} {field_name} insert [{insertions}]")?;
sep = ", ";
}
}
if !patch.appended_fields.is_empty() {
let appended_fields = format_child_list(&patch.appended_fields);
write!(f, "{sep}append [{appended_fields}]")?;
}
write!(f, ")")
}
Unary(UnaryExpression { op, expr }) => write!(f, "{op}({expr})"),
Binary(BinaryExpression { op, left, right }) => write!(f, "{left} {op} {right}"),
Variadic(VariadicExpression { op, exprs }) => {
write!(f, "{op}({})", format_child_list(exprs))
}
// Opaque operators are printed with their `Debug` representation.
Opaque(OpaqueExpression { op, exprs }) => {
write!(f, "{op:?}({})", format_child_list(exprs))
}
Unknown(name) => write!(f, "<unknown: {name}>"),
// The schema is summarized by its field count rather than printed in full.
ParseJson(p) => {
write!(
f,
"PARSE_JSON({}, <schema:{} fields>)",
p.json_expr,
p.output_schema.fields().len()
)
}
MapToStruct(m) => write!(f, "MAP_TO_STRUCT({})", m.map_expr),
}
}
}
impl Display for Predicate {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use Predicate::*;
match self {
BooleanExpression(expr) => write!(f, "{expr}"),
Not(pred) => write!(f, "NOT({pred})"),
Binary(BinaryPredicate {
op: BinaryPredicateOp::Distinct,
left,
right,
}) => write!(f, "DISTINCT({left}, {right})"),
Binary(BinaryPredicate { op, left, right }) => write!(f, "{left} {op} {right}"),
Unary(UnaryPredicate { op, expr }) => match op {
UnaryPredicateOp::IsNull => write!(f, "{expr} IS NULL"),
},
Junction(JunctionPredicate { op, preds }) => {
let op = match op {
JunctionPredicateOp::And => "AND",
JunctionPredicateOp::Or => "OR",
};
write!(f, "{op}({})", format_child_list(preds))
}
Opaque(OpaquePredicate { op, exprs }) => {
write!(f, "{op:?}({})", format_child_list(exprs))
}
Unknown(name) => write!(f, "<unknown: {name}>"),
}
}
}
impl From<Scalar> for Expression {
fn from(value: Scalar) -> Self {
Self::literal(value)
}
}
impl From<ColumnName> for Expression {
fn from(value: ColumnName) -> Self {
Self::Column(value)
}
}
impl From<Predicate> for Expression {
fn from(value: Predicate) -> Self {
Self::from_pred(value)
}
}
impl From<ColumnName> for Predicate {
fn from(value: ColumnName) -> Self {
Self::from_expr(value)
}
}
impl<R: Into<Expression>> std::ops::Add<R> for Expression {
type Output = Self;
fn add(self, rhs: R) -> Self::Output {
Self::binary(BinaryExpressionOp::Plus, self, rhs)
}
}
impl<R: Into<Expression>> std::ops::Sub<R> for Expression {
type Output = Self;
fn sub(self, rhs: R) -> Self {
Self::binary(BinaryExpressionOp::Minus, self, rhs)
}
}
impl<R: Into<Expression>> std::ops::Mul<R> for Expression {
type Output = Self;
fn mul(self, rhs: R) -> Self {
Self::binary(BinaryExpressionOp::Multiply, self, rhs)
}
}
impl<R: Into<Expression>> std::ops::Div<R> for Expression {
type Output = Self;
fn div(self, rhs: R) -> Self {
Self::binary(BinaryExpressionOp::Divide, self, rhs)
}
}
/// Accumulator used by `Expression::references` and `Predicate::references`: a set of
/// borrowed column names.
#[derive(Default)]
struct GetColumnReferences<'a>(HashSet<&'a ColumnName>);
// Visitor that records each column name it is handed; no other hook is overridden here.
impl<'a> ExpressionTransform<'a> for GetColumnReferences<'a> {
// Project macro from `crate::transforms`.
// NOTE(review): presumably declares the transform's output type as `()` — confirm.
transform_output_type!(|'a, T| ());
// Remember the column name; the set de-duplicates repeated references.
fn transform_expr_column(&mut self, name: &'a ColumnName) {
self.0.insert(name);
}
}
#[cfg(test)]
mod tests {
use std::fmt::Debug;
use serde::de::DeserializeOwned;
use serde::Serialize;
use super::{column_expr, column_pred, Expression as Expr, Predicate as Pred};
// Serializes `value` to a JSON string, deserializes it back, and asserts the result equals
// the input. Panics with a descriptive message if either step fails.
fn assert_roundtrip<T: Serialize + DeserializeOwned + PartialEq + Debug>(value: &T) {
let json = serde_json::to_string(value).expect("serialization should succeed");
let deserialized: T = serde_json::from_str(&json).expect("deserialization should succeed");
assert_eq!(value, &deserialized, "roundtrip should preserve value");
}
// Pins the `Display` output of expressions. Binary expressions print without parentheses,
// so `(x + 4) / 10 * 42` is expected to render as `Column(x) + 4 / 10 * 42`.
#[test]
fn test_expression_format() {
let cases = [
(column_expr!("x"), "Column(x)"),
(
(column_expr!("x") + Expr::literal(4)) / Expr::literal(10) * Expr::literal(42),
"Column(x) + 4 / 10 * 42",
),
(
Expr::struct_from([column_expr!("x"), Expr::literal(2), Expr::literal(10)]),
"Struct(Column(x), 2, 10)",
),
(
Expr::array([column_expr!("x"), column_expr!("y"), Expr::literal(0)]),
"ARRAY(Column(x), Column(y), 0)",
),
];
for (expr, expected) in cases {
let result = format!("{expr}");
assert_eq!(result, expected);
}
}
// Pins the `Display` output of predicates. `ge`/`le` are built as negations, so they print
// as `NOT(... < ...)` / `NOT(... > ...)`.
#[test]
fn test_predicate_format() {
let cases = [
(column_pred!("x"), "Column(x)"),
(column_expr!("x").eq(Expr::literal(2)), "Column(x) = 2"),
(
(column_expr!("x") - Expr::literal(4)).lt(Expr::literal(10)),
"Column(x) - 4 < 10",
),
(
Pred::and(
column_expr!("x").ge(Expr::literal(2)),
column_expr!("x").le(Expr::literal(10)),
),
"AND(NOT(Column(x) < 2), NOT(Column(x) > 10))",
),
(
Pred::and_from([
column_expr!("x").ge(Expr::literal(2)),
column_expr!("x").le(Expr::literal(10)),
column_expr!("x").le(Expr::literal(100)),
]),
"AND(NOT(Column(x) < 2), NOT(Column(x) > 10), NOT(Column(x) > 100))",
),
(
Pred::or(
column_expr!("x").gt(Expr::literal(2)),
column_expr!("x").lt(Expr::literal(10)),
),
"OR(Column(x) > 2, Column(x) < 10)",
),
(
column_expr!("x").eq(Expr::literal("foo")),
"Column(x) = 'foo'",
),
];
for (pred, expected) in cases {
let result = format!("{pred}");
assert_eq!(result, expected);
}
}
mod serde_tests {
use std::sync::Arc;
use super::assert_roundtrip;
use crate::expressions::scalars::{ArrayData, DecimalData, MapData, StructData};
use crate::expressions::{
col, column_expr, column_name, lit, BinaryExpressionOp, BinaryPredicateOp, ColumnName,
Expression, ExpressionStructPatchBuilder, Predicate, Scalar, UnaryExpressionOp,
};
use crate::schema::{ArrayType, DataType, DecimalType, MapType, StructField};
use crate::utils::test_utils::assert_result_error_with_message;
// Round-trips one literal of each primitive scalar kind through JSON.
#[test]
fn test_literal_scalars_roundtrip() {
    let cases: Vec<Expression> = vec![
        Expression::literal(42i32),
        Expression::literal(9999999999i64),
        Expression::literal(123i16),
        Expression::literal(42i8),
        // The `f32` / `f64` suffixes matter: without the `f`, a trailing `_32` / `_64` is only
        // a digit separator and both literals would be inferred as the same float type.
        Expression::literal(1.12345677_f32),
        Expression::literal(1.12345667_f64),
        Expression::literal("hello world"),
        Expression::literal(true),
        Expression::literal(false),
        Expression::Literal(Scalar::Timestamp(1234567890000000)),
        Expression::Literal(Scalar::TimestampNtz(1234567890000000)),
        Expression::Literal(Scalar::Date(19000)),
        Expression::Literal(Scalar::Binary(vec![1, 2, 3, 4, 5])),
        Expression::Literal(Scalar::Decimal(
            DecimalData::try_new(12345i128, DecimalType::try_new(10, 2).unwrap()).unwrap(),
        )),
    ];
    for expr in &cases {
        assert_roundtrip(expr);
    }
}
// Round-trips typed NULL literals and array, map and struct scalar literals through JSON.
#[test]
fn test_literal_complex_scalars_roundtrip() {
let cases: Vec<Expression> = vec![
Expression::null_literal(DataType::INTEGER),
Expression::null_literal(DataType::STRING),
Expression::null_literal(DataType::BOOLEAN),
Expression::Literal(Scalar::Array(
ArrayData::try_new(
ArrayType::new(DataType::INTEGER, false),
vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)],
)
.unwrap(),
)),
Expression::Literal(Scalar::Map(
MapData::try_new(
MapType::new(DataType::STRING, DataType::INTEGER, false),
vec![
(Scalar::String("a".to_string()), Scalar::Integer(1)),
(Scalar::String("b".to_string()), Scalar::Integer(2)),
],
)
.unwrap(),
)),
Expression::Literal(Scalar::Struct(
StructData::try_new(
vec![
StructField::nullable("x", DataType::INTEGER),
StructField::nullable("y", DataType::STRING),
],
vec![Scalar::Integer(42), Scalar::String("hello".to_string())],
)
.unwrap(),
)),
];
for expr in &cases {
assert_roundtrip(expr);
}
}
// Round-trips column-reference expressions with one, two and four name parts.
#[test]
fn test_column_expressions_roundtrip() {
let cases: Vec<Expression> = vec![
column_expr!("my_column"),
Expression::column(["parent", "child"]),
Expression::column(["a", "b", "c", "d"]),
];
for expr in &cases {
assert_roundtrip(expr);
}
}
// Round-trips bare `ColumnName` values, including the default-constructed one.
#[test]
fn test_column_names_roundtrip() {
let cases: Vec<ColumnName> = vec![
column_name!("simple"),
ColumnName::new(["a", "b", "c"]),
ColumnName::default(),
];
for col in &cases {
assert_roundtrip(col);
}
}
// Round-trips a `ToJson` unary expression over a column.
#[test]
fn test_unary_expression_roundtrip() {
let expr = Expression::unary(UnaryExpressionOp::ToJson, column_expr!("data"));
assert_roundtrip(&expr);
}
// Round-trips a `column <op> literal` expression for each of the four arithmetic operators.
#[test]
fn test_binary_expressions_roundtrip() {
let ops = [
BinaryExpressionOp::Plus,
BinaryExpressionOp::Minus,
BinaryExpressionOp::Multiply,
BinaryExpressionOp::Divide,
];
for op in ops {
let expr = Expression::binary(op, column_expr!("a"), Expression::literal(10));
assert_roundtrip(&expr);
}
}
// Round-trips a `COALESCE` over two columns and a string literal.
#[test]
fn test_variadic_expression_roundtrip() {
let expr = Expression::coalesce([
column_expr!("a"),
column_expr!("b"),
Expression::literal("default"),
]);
assert_roundtrip(&expr);
}
// Parameterized (rstest) round-trip of `ARRAY` expressions: a single-literal array and an
// array mixing columns with a literal.
#[rstest::rstest]
#[case::array_single(Expression::array([Expression::literal(7i32)]))]
#[case::array_mixed(Expression::array([
column_expr!("a"),
column_expr!("b"),
Expression::literal(42i64),
]))]
fn test_array_expression_roundtrip(#[case] expr: Expression) {
assert_roundtrip(&expr);
}
// Round-trips the nested arithmetic tree ((a + b) * (c - d)) / 2.
#[test]
fn test_nested_arithmetic_expression_roundtrip() {
let left = Expression::binary(
BinaryExpressionOp::Plus,
column_expr!("a"),
column_expr!("b"),
);
let right = Expression::binary(
BinaryExpressionOp::Minus,
column_expr!("c"),
column_expr!("d"),
);
let mul = Expression::binary(BinaryExpressionOp::Multiply, left, right);
let expr = Expression::binary(BinaryExpressionOp::Divide, mul, Expression::literal(2));
assert_roundtrip(&expr);
}
// Round-trips a struct expression built from a column and two literals.
#[test]
fn test_struct_expression_roundtrip() {
let expr = Expression::struct_from([
Arc::new(column_expr!("x")),
Arc::new(Expression::literal(42)),
Arc::new(Expression::literal("hello")),
]);
assert_roundtrip(&expr);
}
// Round-trips struct-patch expressions: an empty patch, a drop, a replace, a combination of
// insert-after/prepend/append, and a nested patch with a drop.
#[test]
fn test_transform_expressions_roundtrip() {
let cases: Vec<Expression> = vec![
Expression::struct_patch(ExpressionStructPatchBuilder::new()).unwrap(),
Expression::struct_patch(ExpressionStructPatchBuilder::new().drop("old_column"))
.unwrap(),
Expression::struct_patch(
ExpressionStructPatchBuilder::new().replace("original", lit(0)),
)
.unwrap(),
Expression::struct_patch(
ExpressionStructPatchBuilder::new()
.insert_after("after_col", col!("new_col"))
.prepend(lit("prepended"))
.append(lit("appended")),
)
.unwrap(),
Expression::struct_patch(
ExpressionStructPatchBuilder::new_nested(["parent", "child"]).drop("to_drop"),
)
.unwrap(),
];
for expr in &cases {
assert_roundtrip(expr);
}
}
// Round-trips an expression obtained by wrapping an equality predicate.
#[test]
fn test_expression_wrapping_predicate_roundtrip() {
let pred = Predicate::eq(column_expr!("x"), Expression::literal(10));
let expr = Expression::from_pred(pred);
assert_roundtrip(&expr);
}
// Round-trips an `Unknown` placeholder expression.
#[test]
fn test_expression_unknown_roundtrip() {
let expr = Expression::unknown("some_unknown_function()");
assert_roundtrip(&expr);
}
// Round-trips `MAP_TO_STRUCT` expressions over a column and over a literal.
#[test]
fn test_map_to_struct_expression_roundtrip() {
let cases: Vec<Expression> = vec![
Expression::map_to_struct(column_expr!("pv")),
Expression::map_to_struct(Expression::literal("ignored")),
];
for expr in &cases {
assert_roundtrip(expr);
}
}
// Round-trips basic predicates: a column, boolean literals, single and double negation,
// an unknown placeholder, and IS NULL / IS NOT NULL.
#[test]
fn test_predicate_basics_roundtrip() {
let cases: Vec<Predicate> = vec![
Predicate::from_expr(column_expr!("is_active")),
Predicate::literal(true),
Predicate::literal(false),
Predicate::not(Predicate::from_expr(column_expr!("x"))),
Predicate::not(Predicate::not(Predicate::gt(
column_expr!("x"),
Expression::literal(5),
))),
Predicate::unknown("some_unknown_predicate()"),
Predicate::is_null(column_expr!("nullable_col")),
Predicate::is_not_null(column_expr!("nullable_col")),
];
for pred in &cases {
assert_roundtrip(pred);
}
}
// Round-trips the boolean-typed NULL literal predicate.
#[test]
fn test_predicate_null_literal_roundtrip() {
let pred = Predicate::null_literal();
assert_roundtrip(&pred);
}
// Round-trips one predicate per comparison builder: eq, ne, lt, le, gt, ge and distinct.
#[test]
fn test_predicate_comparisons_roundtrip() {
let cases: Vec<Predicate> = vec![
Predicate::eq(column_expr!("x"), Expression::literal(42)),
Predicate::ne(column_expr!("status"), Expression::literal("active")),
Predicate::lt(column_expr!("age"), Expression::literal(18)),
Predicate::le(column_expr!("price"), Expression::literal(100)),
Predicate::gt(column_expr!("score"), Expression::literal(90)),
Predicate::ge(column_expr!("quantity"), Expression::literal(1)),
Predicate::distinct(column_expr!("a"), column_expr!("b")),
];
for pred in &cases {
assert_roundtrip(pred);
}
}
// Round-trips an `IN` predicate whose right-hand side is an integer array literal.
#[test]
fn test_predicate_in_roundtrip() {
let array_data = ArrayData::try_new(
ArrayType::new(DataType::INTEGER, false),
vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)],
)
.unwrap();
let pred = Predicate::binary(
BinaryPredicateOp::In,
column_expr!("x"),
Expression::Literal(Scalar::Array(array_data)),
);
assert_roundtrip(&pred);
}
// Round-trips AND/OR junctions: two-element, three-element, and an OR that nests an AND.
#[test]
fn test_predicate_junctions_roundtrip() {
let cases: Vec<Predicate> = vec![
Predicate::and(
Predicate::gt(column_expr!("x"), Expression::literal(0)),
Predicate::lt(column_expr!("x"), Expression::literal(100)),
),
Predicate::or(
Predicate::eq(column_expr!("status"), Expression::literal("active")),
Predicate::eq(column_expr!("status"), Expression::literal("pending")),
),
Predicate::and_from([
Predicate::gt(column_expr!("x"), Expression::literal(0)),
Predicate::lt(column_expr!("x"), Expression::literal(100)),
Predicate::is_not_null(column_expr!("x")),
]),
Predicate::or_from([
Predicate::eq(column_expr!("type"), Expression::literal("A")),
Predicate::eq(column_expr!("type"), Expression::literal("B")),
Predicate::eq(column_expr!("type"), Expression::literal("C")),
]),
Predicate::or(
Predicate::and(
Predicate::gt(column_expr!("a"), Expression::literal(0)),
Predicate::lt(column_expr!("b"), Expression::literal(100)),
),
Predicate::eq(column_expr!("c"), Expression::literal("special")),
),
];
for pred in &cases {
assert_roundtrip(pred);
}
}
// Round-trips two deeper trees: a comparison over a COALESCE of arithmetic expressions, and
// an expression wrapping an AND of comparisons.
#[test]
fn test_deeply_nested_structures_roundtrip() {
let add = Expression::binary(
BinaryExpressionOp::Plus,
column_expr!("a"),
column_expr!("b"),
);
let mul = Expression::binary(
BinaryExpressionOp::Multiply,
column_expr!("c"),
column_expr!("d"),
);
let coalesce = Expression::coalesce([add, mul, Expression::literal(0)]);
let pred = Predicate::gt(coalesce, Expression::literal(100));
assert_roundtrip(&pred);
let inner_pred = Predicate::and(
Predicate::eq(column_expr!("x"), Expression::literal(1)),
Predicate::gt(
Expression::binary(
BinaryExpressionOp::Plus,
column_expr!("y"),
column_expr!("z"),
),
Expression::literal(10),
),
);
let expr = Expression::from_pred(inner_pred);
assert_roundtrip(&expr);
}
// Serializing an opaque expression must fail with the dedicated error message.
#[test]
fn test_opaque_expression_serialize_fails() {
use crate::expressions::{OpaqueExpressionOp, ScalarExpressionEvaluator};
use crate::DeltaResult;
// Minimal operator stub; its methods are not exercised by the serialization attempt.
#[derive(Debug, PartialEq)]
struct TestOpaqueExprOp;
impl OpaqueExpressionOp for TestOpaqueExprOp {
fn name(&self) -> &str {
"test_opaque"
}
fn eval_expr_scalar(
&self,
_eval_expr: &ScalarExpressionEvaluator<'_>,
_exprs: &[Expression],
) -> DeltaResult<Scalar> {
Ok(Scalar::Integer(0))
}
}
let expr = Expression::opaque(TestOpaqueExprOp, [Expression::literal(1)]);
let result = serde_json::to_string(&expr);
assert_result_error_with_message(result, "Cannot serialize an Opaque Expression");
}
// Serializing an opaque predicate must fail with the dedicated error message.
#[test]
fn test_opaque_predicate_serialize_fails() {
use crate::expressions::{OpaquePredicateOp, ScalarExpressionEvaluator};
use crate::kernel_predicates::{
DirectDataSkippingPredicateEvaluator, DirectPredicateEvaluator,
IndirectDataSkippingPredicateEvaluator,
};
use crate::DeltaResult;
// Minimal operator stub; its methods are not exercised by the serialization attempt.
#[derive(Debug, PartialEq)]
struct TestOpaquePredOp;
impl OpaquePredicateOp for TestOpaquePredOp {
fn name(&self) -> &str {
"test_opaque_pred"
}
fn eval_pred_scalar(
&self,
_eval_expr: &ScalarExpressionEvaluator<'_>,
_eval_pred: &DirectPredicateEvaluator<'_>,
_exprs: &[Expression],
_inverted: bool,
) -> DeltaResult<Option<bool>> {
Ok(Some(true))
}
fn eval_as_data_skipping_predicate(
&self,
_evaluator: &DirectDataSkippingPredicateEvaluator<'_>,
_exprs: &[Expression],
_inverted: bool,
) -> Option<bool> {
Some(true)
}
fn as_data_skipping_predicate(
&self,
_evaluator: &IndirectDataSkippingPredicateEvaluator<'_>,
_exprs: &[Expression],
_inverted: bool,
) -> Option<Predicate> {
None
}
}
let pred = Predicate::opaque(TestOpaquePredOp, [Expression::literal(1)]);
let result = serde_json::to_string(&pred);
assert_result_error_with_message(result, "Cannot serialize an Opaque Predicate");
}
}
// A one-element AND collapses to the element itself rather than a junction node.
#[test]
fn single_element_and_from_returns_unwrapped_predicate() {
let inner = Pred::gt(column_expr!("x"), Expr::literal(0));
let result = Pred::and_from([inner.clone()]);
assert_eq!(result, inner);
}
// A one-element OR collapses to the element itself rather than a junction node.
#[test]
fn single_element_or_from_returns_unwrapped_predicate() {
let inner = Pred::gt(column_expr!("x"), Expr::literal(0));
let result = Pred::or_from([inner.clone()]);
assert_eq!(result, inner);
}
// Two or more predicates produce a real junction node, identical to what `Pred::and` builds.
#[test]
fn multi_element_and_from_returns_junction() {
let p1 = Pred::gt(column_expr!("x"), Expr::literal(0));
let p2 = Pred::lt(column_expr!("x"), Expr::literal(100));
let result = Pred::and_from([p1.clone(), p2.clone()]);
assert!(matches!(result, Pred::Junction(ref j) if j.preds.len() == 2));
assert_eq!(result, Pred::and(p1, p2));
}
// An empty AND collapses to the literal `true`.
#[test]
fn empty_and_from_returns_identity_literal() {
let result = Pred::and_from(std::iter::empty());
assert_eq!(result, Pred::literal(true));
}
// An empty OR collapses to the literal `false`.
#[test]
fn empty_or_from_returns_identity_literal() {
let result = Pred::or_from(std::iter::empty());
assert_eq!(result, Pred::literal(false));
}
}