use std::sync::Arc;
use vortex_error::VortexExpect;
use vortex_error::vortex_panic;
use vortex_utils::iter::ReduceBalancedIterExt;
use crate::dtype::DType;
use crate::dtype::FieldName;
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::expr::Expression;
use crate::scalar::Scalar;
use crate::scalar::ScalarValue;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::ScalarFnVTableExt;
use crate::scalar_fn::fns::between::Between;
use crate::scalar_fn::fns::between::BetweenOptions;
use crate::scalar_fn::fns::binary::Binary;
use crate::scalar_fn::fns::case_when::CaseWhen;
use crate::scalar_fn::fns::case_when::CaseWhenOptions;
use crate::scalar_fn::fns::cast::Cast;
use crate::scalar_fn::fns::dynamic::DynamicComparison;
use crate::scalar_fn::fns::dynamic::DynamicComparisonExpr;
use crate::scalar_fn::fns::dynamic::Rhs;
use crate::scalar_fn::fns::fill_null::FillNull;
use crate::scalar_fn::fns::get_item::GetItem;
use crate::scalar_fn::fns::is_null::IsNull;
use crate::scalar_fn::fns::like::Like;
use crate::scalar_fn::fns::like::LikeOptions;
use crate::scalar_fn::fns::list_contains::ListContains;
use crate::scalar_fn::fns::literal::Literal;
use crate::scalar_fn::fns::mask::Mask;
use crate::scalar_fn::fns::merge::DuplicateHandling;
use crate::scalar_fn::fns::merge::Merge;
use crate::scalar_fn::fns::not::Not;
use crate::scalar_fn::fns::operators::CompareOperator;
use crate::scalar_fn::fns::operators::Operator;
use crate::scalar_fn::fns::pack::Pack;
use crate::scalar_fn::fns::pack::PackOptions;
use crate::scalar_fn::fns::root::Root;
use crate::scalar_fn::fns::select::FieldSelection;
use crate::scalar_fn::fns::select::Select;
use crate::scalar_fn::fns::zip::Zip;
/// Creates a [`Root`] expression with empty options and no children.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the `Root` expression cannot be constructed.
pub fn root() -> Expression {
    let created = Root.try_new_expr(EmptyOptions, vec![]);
    created.vortex_expect("Failed to create Root expression")
}
/// Reports whether `expr` is an expression of the [`Root`] kind.
///
/// Returns `true` for expressions such as the one produced by [`root`], `false` otherwise.
pub fn is_root(expr: &Expression) -> bool {
    expr.is::<Root>()
}
/// Wraps `value` as a [`Literal`] expression with no children.
///
/// Any type convertible into a [`Scalar`] is accepted.
pub fn lit(value: impl Into<Scalar>) -> Expression {
    let scalar: Scalar = value.into();
    Literal.new_expr(scalar, [])
}
/// Builds a [`GetItem`] expression that reads `field` from the [`root`] expression.
pub fn col(field: impl Into<FieldName>) -> Expression {
    let name: FieldName = field.into();
    GetItem.new_expr(name, vec![root()])
}
/// Builds a [`GetItem`] expression that reads `field` from `child`.
pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
    let name: FieldName = field.into();
    let children = vec![child];
    GetItem.new_expr(name, children)
}
/// Builds a single-branch [`CaseWhen`] expression with an else value.
///
/// Children are ordered `[condition, then_value, else_value]`.
pub fn case_when(
    condition: Expression,
    then_value: Expression,
    else_value: Expression,
) -> Expression {
    CaseWhen.new_expr(
        CaseWhenOptions {
            has_else: true,
            num_when_then_pairs: 1,
        },
        [condition, then_value, else_value],
    )
}
/// Builds a single-branch [`CaseWhen`] expression without an else value.
///
/// Children are ordered `[condition, then_value]`.
pub fn case_when_no_else(condition: Expression, then_value: Expression) -> Expression {
    CaseWhen.new_expr(
        CaseWhenOptions {
            has_else: false,
            num_when_then_pairs: 1,
        },
        [condition, then_value],
    )
}
pub fn nested_case_when(
when_then_pairs: Vec<(Expression, Expression)>,
else_value: Option<Expression>,
) -> Expression {
assert!(
!when_then_pairs.is_empty(),
"nested_case_when requires at least one when/then pair"
);
let has_else = else_value.is_some();
let mut children = Vec::with_capacity(when_then_pairs.len() * 2 + usize::from(has_else));
for (condition, then_value) in &when_then_pairs {
children.push(condition.clone());
children.push(then_value.clone());
}
if let Some(else_expr) = else_value {
children.push(else_expr);
}
let Ok(num_when_then_pairs) = u32::try_from(when_then_pairs.len()) else {
vortex_panic!("nested_case_when has too many when/then pairs");
};
let options = CaseWhenOptions {
num_when_then_pairs,
has_else,
};
CaseWhen.new_expr(options, children)
}
/// Builds a [`Binary`] equality comparison (`Operator::Eq`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn eq(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Eq, operands)
        .vortex_expect("Failed to create Eq binary expression")
}
/// Builds a [`Binary`] inequality comparison (`Operator::NotEq`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn not_eq(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::NotEq, operands)
        .vortex_expect("Failed to create NotEq binary expression")
}
/// Builds a [`Binary`] greater-than-or-equal comparison (`Operator::Gte`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn gt_eq(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Gte, operands)
        .vortex_expect("Failed to create Gte binary expression")
}
/// Builds a [`Binary`] greater-than comparison (`Operator::Gt`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn gt(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Gt, operands)
        .vortex_expect("Failed to create Gt binary expression")
}
/// Builds a [`Binary`] less-than-or-equal comparison (`Operator::Lte`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn lt_eq(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Lte, operands)
        .vortex_expect("Failed to create Lte binary expression")
}
/// Builds a [`Binary`] less-than comparison (`Operator::Lt`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn lt(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Lt, operands)
        .vortex_expect("Failed to create Lt binary expression")
}
/// Builds a [`Binary`] logical disjunction (`Operator::Or`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn or(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Or, operands)
        .vortex_expect("Failed to create Or binary expression")
}
/// Combines every expression in `iter` with [`or`] using `reduce_balanced`.
///
/// NOTE(review): presumably this yields `None` for an empty iterator and builds a balanced
/// (rather than left-deep) tree, as the `ReduceBalancedIterExt` name suggests — confirm there.
pub fn or_collect<I>(iter: I) -> Option<Expression>
where
    I: IntoIterator<Item = Expression>,
{
    let exprs = iter.into_iter();
    exprs.reduce_balanced(or)
}
/// Builds a [`Binary`] logical conjunction (`Operator::And`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn and(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::And, operands)
        .vortex_expect("Failed to create And binary expression")
}
/// Combines every expression in `iter` with [`and`] using `reduce_balanced`.
///
/// NOTE(review): presumably this yields `None` for an empty iterator and builds a balanced
/// (rather than left-deep) tree, as the `ReduceBalancedIterExt` name suggests — confirm there.
pub fn and_collect<I>(iter: I) -> Option<Expression>
where
    I: IntoIterator<Item = Expression>,
{
    let exprs = iter.into_iter();
    exprs.reduce_balanced(and)
}
/// Builds a [`Binary`] addition (`Operator::Add`) of `lhs` and `rhs`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the binary expression cannot be constructed.
pub fn checked_add(lhs: Expression, rhs: Expression) -> Expression {
    let operands = [lhs, rhs];
    Binary
        .try_new_expr(Operator::Add, operands)
        .vortex_expect("Failed to create Add binary expression")
}
/// Builds a [`Not`] expression over the single child `operand`.
pub fn not(operand: Expression) -> Expression {
    let children = vec![operand];
    Not.new_expr(EmptyOptions, children)
}
/// Builds a [`Between`] expression testing `arr` against the bounds `lower` and `upper`.
///
/// Children are ordered `[arr, lower, upper]`; `options` is passed through unchanged.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the `Between` expression cannot be constructed.
pub fn between(
    arr: Expression,
    lower: Expression,
    upper: Expression,
    options: BetweenOptions,
) -> Expression {
    let children = [arr, lower, upper];
    let built = Between.try_new_expr(options, children);
    built.vortex_expect("Failed to create Between expression")
}
/// Builds a [`Select`] expression keeping only `field_names` from `child`.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the `Select` expression cannot be constructed.
pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
    let selection = FieldSelection::Include(field_names.into());
    Select
        .try_new_expr(selection, [child])
        .vortex_expect("Failed to create Select expression")
}
/// Builds a [`Select`] expression that drops `fields` from `child` (`FieldSelection::Exclude`).
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the `Select` expression cannot be constructed.
pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
    let selection = FieldSelection::Exclude(fields.into());
    Select
        .try_new_expr(selection, [child])
        .vortex_expect("Failed to create Select expression")
}
/// Builds a [`Pack`] expression from `(name, value)` pairs with the given `nullability`.
///
/// The names are collected into the options, in iteration order. The values become the
/// children, in the same order.
pub fn pack(
    elements: impl IntoIterator<Item = (impl Into<FieldName>, Expression)>,
    nullability: Nullability,
) -> Expression {
    let mut names: Vec<FieldName> = Vec::new();
    let mut values: Vec<Expression> = Vec::new();
    for (name, value) in elements {
        names.push(name.into());
        values.push(value);
    }
    let options = PackOptions {
        names: names.into(),
        nullability,
    };
    Pack.new_expr(options, values)
}
/// Builds a [`Cast`] expression converting `child` to the `target` dtype.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if the `Cast` expression cannot be constructed.
pub fn cast(child: Expression, target: DType) -> Expression {
    let children = [child];
    Cast.try_new_expr(target, children)
        .vortex_expect("Failed to create Cast expression")
}
/// Builds a [`FillNull`] expression with children `[child, fill_value]`.
pub fn fill_null(child: Expression, fill_value: Expression) -> Expression {
    let children = [child, fill_value];
    FillNull.new_expr(EmptyOptions, children)
}
/// Builds an [`IsNull`] expression over the single child `child`.
pub fn is_null(child: Expression) -> Expression {
    let children = vec![child];
    IsNull.new_expr(EmptyOptions, children)
}
/// Builds a case-sensitive, non-negated [`Like`] match of `child` against `pattern`.
pub fn like(child: Expression, pattern: Expression) -> Expression {
    let options = LikeOptions {
        case_insensitive: false,
        negated: false,
    };
    Like.new_expr(options, [child, pattern])
}
/// Builds a case-insensitive, non-negated [`Like`] match of `child` against `pattern`.
pub fn ilike(child: Expression, pattern: Expression) -> Expression {
    let options = LikeOptions {
        case_insensitive: true,
        negated: false,
    };
    Like.new_expr(options, [child, pattern])
}
/// Builds a case-sensitive, negated [`Like`] match of `child` against `pattern`.
pub fn not_like(child: Expression, pattern: Expression) -> Expression {
    let options = LikeOptions {
        case_insensitive: false,
        negated: true,
    };
    Like.new_expr(options, [child, pattern])
}
/// Builds a case-insensitive, negated [`Like`] match of `child` against `pattern`.
pub fn not_ilike(child: Expression, pattern: Expression) -> Expression {
    let options = LikeOptions {
        case_insensitive: true,
        negated: true,
    };
    Like.new_expr(options, [child, pattern])
}
/// Builds a [`Mask`] expression with children `[array, mask]`.
pub fn mask(array: Expression, mask: Expression) -> Expression {
    let children = [array, mask];
    Mask.new_expr(EmptyOptions, children)
}
pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
use itertools::Itertools as _;
let values = elements.into_iter().map(|value| value.into()).collect_vec();
Merge.new_expr(DuplicateHandling::default(), values)
}
/// Builds a [`Merge`] expression over `elements` with an explicit `duplicate_handling` policy.
///
/// Each element is converted into an [`Expression`] and becomes a child, in iteration order.
pub fn merge_opts(
    elements: impl IntoIterator<Item = impl Into<Expression>>,
    duplicate_handling: DuplicateHandling,
) -> Expression {
    let values: Vec<Expression> = elements.into_iter().map(Into::into).collect();
    Merge.new_expr(duplicate_handling, values)
}
/// Builds a [`Zip`] expression from `mask`, `if_true` and `if_false`.
///
/// The children are ordered `[if_true, if_false, mask]`, which differs from this
/// function's parameter order. NOTE(review): presumably `mask` selects between the two
/// branches element-wise — confirm against `Zip`.
pub fn zip_expr(mask: Expression, if_true: Expression, if_false: Expression) -> Expression {
    let children = [if_true, if_false, mask];
    Zip.new_expr(EmptyOptions, children)
}
/// Builds a [`DynamicComparison`] expression comparing `lhs` using `operator`.
///
/// The right-hand side is described by the closure `rhs_value` together with its
/// `rhs_dtype`; both are stored (behind `Arc`s) in an [`Rhs`] on the expression's options.
/// `default` is stored on the options as well. NOTE(review): presumably `default` is the
/// result used when `rhs_value` yields `None` — confirm against `DynamicComparison`.
pub fn dynamic(
    operator: CompareOperator,
    rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
    rhs_dtype: DType,
    default: bool,
    lhs: Expression,
) -> Expression {
    let rhs = Rhs {
        value: Arc::new(rhs_value),
        dtype: rhs_dtype,
    };
    let options = DynamicComparisonExpr {
        operator,
        rhs: Arc::new(rhs),
        default,
    };
    DynamicComparison.new_expr(options, [lhs])
}
/// Builds a [`ListContains`] expression with children `[list, value]`.
pub fn list_contains(list: Expression, value: Expression) -> Expression {
    let children = [list, value];
    ListContains.new_expr(EmptyOptions, children)
}