use crate::expr::{
HigherOrderFunction, display_comma_separated,
schema_name_from_exprs_comma_separated_without_space,
};
use crate::type_coercion::functions::value_fields_with_higher_order_udf;
use crate::udf_eq::UdfEq;
use crate::{ColumnarValue, Documentation, Expr, ExprSchemable};
use arrow::array::{ArrayRef, RecordBatch};
use arrow::datatypes::{DataType, FieldRef, Schema};
use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::datatype::FieldExt;
use datafusion_common::hash_map::EntryRef;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
};
use datafusion_common::{
DFSchema, HashMap, HashSet, Result, ScalarValue, exec_err, internal_datafusion_err,
internal_err, not_impl_err, plan_datafusion_err, plan_err,
};
use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
use datafusion_expr_common::signature::Volatility;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::any::Any;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::mem;
use std::sync::Arc;
/// Describes which argument lists a higher-order function accepts.
///
/// NOTE: `PartialOrd` is derived, so the order of the variants below determines how
/// signatures compare; do not reorder them casually.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum HigherOrderTypeSignature {
    /// Argument validation/coercion is left to the implementation (presumably through
    /// `HigherOrderUDFImpl::coerce_value_types` — TODO confirm in the type coercion code).
    UserDefined,
    /// Presumably any number of arguments of any kind — TODO confirm in type coercion.
    VariadicAny,
    /// Presumably exactly this many arguments of any kind — TODO confirm in type coercion.
    Any(usize),
    /// An explicit sequence of value / lambda argument slots.
    Exact(Vec<ValueOrLambda<(), ()>>),
}
/// Signature of a higher-order function.
///
/// NOTE: `PartialOrd` is derived, so field order determines comparison order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct HigherOrderSignature {
    /// Which argument lists are accepted.
    pub type_signature: HigherOrderTypeSignature,
    /// Volatility of the function.
    pub volatility: Volatility,
    /// Bound on the number of `lambda_parameters` steps attempted while resolving lambda
    /// variables; exceeding it is reported as a plan error.
    pub lambda_parameters_max_iterations: usize,
}
/// Default for `HigherOrderSignature::lambda_parameters_max_iterations`.
const LAMBDA_PARAMETERS_MAX_ITERATIONS: usize = 256;
impl HigherOrderSignature {
    /// Builds a signature from a type signature and a volatility, using the default
    /// lambda-parameter iteration limit.
    pub fn new(type_signature: HigherOrderTypeSignature, volatility: Volatility) -> Self {
        let lambda_parameters_max_iterations = LAMBDA_PARAMETERS_MAX_ITERATIONS;
        HigherOrderSignature {
            type_signature,
            volatility,
            lambda_parameters_max_iterations,
        }
    }

    /// Signature using [`HigherOrderTypeSignature::UserDefined`].
    pub fn user_defined(volatility: Volatility) -> Self {
        Self::new(HigherOrderTypeSignature::UserDefined, volatility)
    }

    /// Signature using [`HigherOrderTypeSignature::VariadicAny`].
    pub fn variadic_any(volatility: Volatility) -> Self {
        Self::new(HigherOrderTypeSignature::VariadicAny, volatility)
    }

    /// Signature using [`HigherOrderTypeSignature::Any`] with `arg_count` arguments.
    pub fn any(arg_count: usize, volatility: Volatility) -> Self {
        Self::new(HigherOrderTypeSignature::Any(arg_count), volatility)
    }

    /// Signature using [`HigherOrderTypeSignature::Exact`] with the given argument slots.
    pub fn exact(args: Vec<ValueOrLambda<(), ()>>, volatility: Volatility) -> Self {
        Self::new(HigherOrderTypeSignature::Exact(args), volatility)
    }
}
/// Equality is whatever the concrete implementation's `DynEq` says.
impl PartialEq for dyn HigherOrderUDFImpl {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other as _)
    }
}
impl PartialOrd for dyn HigherOrderUDFImpl {
    /// Orders by name, then signature, then aliases.
    ///
    /// Returns `None` when the signatures are incomparable, and also when name, signature
    /// and aliases all tie but the implementations are not `==`, so that
    /// `Some(Ordering::Equal)` always agrees with `PartialEq`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = match self.name().cmp(other.name()) {
            Ordering::Equal => match self.signature().partial_cmp(other.signature())? {
                Ordering::Equal => self.aliases().partial_cmp(other.aliases())?,
                by_signature => by_signature,
            },
            by_name => by_name,
        };
        // Instances may differ in properties not observed above (type, parameters).
        if ordering == Ordering::Equal && self != other {
            return None;
        }
        debug_assert!(
            ordering == Ordering::Equal || self != other,
            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
            The functions compare as equal, but they are not equal based on general properties that \
            the PartialOrd implementation observes,",
            self.name(),
            other.name()
        );
        Some(ordering)
    }
}
impl Eq for dyn HigherOrderUDFImpl {}
/// Hashing is whatever the concrete implementation's `DynHash` says.
impl Hash for dyn HigherOrderUDFImpl {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.dyn_hash(state)
    }
}
/// Arguments passed to [`HigherOrderUDFImpl::invoke_with_args`].
#[derive(Debug, Clone)]
pub struct HigherOrderFunctionArgs {
    /// One entry per call argument: an evaluated value, or a lambda ready for evaluation.
    pub args: Vec<ValueOrLambda<ColumnarValue, LambdaArgument>>,
    /// Fields of the call arguments, in the same order as `args`.
    pub arg_fields: Vec<ValueOrLambda<FieldRef, FieldRef>>,
    /// Number of rows to produce (presumably the input batch size — TODO confirm).
    pub number_rows: usize,
    /// Field the function result must conform to; see `return_type`.
    pub return_field: FieldRef,
    /// Session configuration.
    pub config_options: Arc<ConfigOptions>,
}
impl HigherOrderFunctionArgs {
    /// Data type of the expected result, taken from `return_field`.
    pub fn return_type(&self) -> &DataType {
        let field = self.return_field.as_ref();
        field.data_type()
    }
}
/// A lambda argument in physical form, ready to be evaluated by a higher-order function.
#[derive(Clone, Debug)]
pub struct LambdaArgument {
    /// Fields of the lambda's parameters, in declaration order.
    params: Vec<FieldRef>,
    /// Compiled lambda body.
    body: Arc<dyn PhysicalExpr>,
    /// Schema the body is evaluated against: captured columns' fields (if any) followed
    /// by `params` (built in `LambdaArgument::new`).
    schema: SchemaRef,
    /// Columns from the enclosing scope that the body references, if any.
    captures: Option<RecordBatch>,
}
impl LambdaArgument {
    /// Creates a lambda argument from its parameter fields, compiled body and optional
    /// captured columns.
    ///
    /// The evaluation schema is the captured columns' fields (if any) followed by `params`,
    /// which is the column order `evaluate` assembles its input batch in.
    pub fn new(
        params: Vec<FieldRef>,
        body: Arc<dyn PhysicalExpr>,
        captures: Option<RecordBatch>,
    ) -> Self {
        let mut fields: Vec<FieldRef> = match &captures {
            Some(batch) => batch.schema_ref().fields().iter().cloned().collect(),
            None => Vec::with_capacity(params.len()),
        };
        // Extend in place rather than cloning `params` into a temporary Vec first.
        fields.extend(params.iter().cloned());
        let schema = Arc::new(Schema::new(fields));
        Self {
            params,
            body,
            schema,
            captures,
        }
    }

    /// Evaluates the lambda body.
    ///
    /// `args` supplies one lazily computed column per lambda parameter; it must have at
    /// least as many entries as there are parameters (extra trailing entries are ignored).
    /// `spread_captures` is only called when the lambda has captures and receives the
    /// captured columns; it must return columns whose length matches the parameter
    /// columns (presumably by repeating each captured row per element — TODO confirm
    /// with callers).
    ///
    /// # Errors
    /// Propagates errors from `spread_captures`, from the `args` closures, from building
    /// the input batch (e.g. a column length mismatch) and from evaluating the body.
    pub fn evaluate(
        &self,
        args: &[&dyn Fn() -> Result<ArrayRef>],
        spread_captures: impl FnOnce(&[ArrayRef]) -> Result<Vec<ArrayRef>>,
    ) -> Result<ColumnarValue> {
        // A `match` instead of `Option::map`: inside a closure whose error type is inferred
        // from `RecordBatch::try_new` (`ArrowError`), `?` would wrap errors from
        // `spread_captures` as Arrow `ExternalError`s before they reach the caller.
        let spread = match &self.captures {
            Some(captures) => {
                let columns = spread_captures(captures.columns())?;
                Some(RecordBatch::try_new(captures.schema(), columns)?)
            }
            None => None,
        };
        let batch = merge_captures_with_variables(
            spread.as_ref(),
            Arc::clone(&self.schema),
            &self.params,
            args,
        )?;
        self.body.evaluate(&batch)
    }
}
/// Builds the batch a lambda body is evaluated on: the captured columns (if any) followed
/// by one column per parameter, computed by calling the first `params.len()` closures of
/// `variables` in order.
///
/// # Errors
/// Execution error if `variables` has fewer entries than `params`; otherwise propagates
/// errors from the closures and from `RecordBatch::try_new` against `schema`.
fn merge_captures_with_variables(
    captures: Option<&RecordBatch>,
    schema: SchemaRef,
    params: &[FieldRef],
    variables: &[&dyn Fn() -> Result<ArrayRef>],
) -> Result<RecordBatch> {
    let Some(bound) = variables.get(..params.len()) else {
        return exec_err!(
            "expected at least {} lambda arguments to merge with captures, got {}",
            params.len(),
            variables.len()
        );
    };
    let mut columns = match captures {
        Some(batch) => batch.columns().to_vec(),
        None => Vec::with_capacity(bound.len()),
    };
    for produce in bound {
        columns.push(produce()?);
    }
    Ok(RecordBatch::try_new(schema, columns)?)
}
/// Arguments for [`HigherOrderUDFImpl::return_field_from_args`].
#[derive(Clone, Debug)]
pub struct HigherOrderReturnFieldArgs<'a> {
    /// One entry per call argument: the value's field, or for a lambda argument a field
    /// describing it (presumably the lambda body's output — TODO confirm at the call site).
    pub arg_fields: &'a [ValueOrLambda<FieldRef, FieldRef>],
    /// One entry per call argument; presumably `Some` when the argument is a literal
    /// scalar — TODO confirm.
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}
/// Either a value argument (`V`) or a lambda argument (`L`) of a higher-order function.
///
/// NOTE: the derived `PartialOrd` orders every `Value` before every `Lambda`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum ValueOrLambda<V, L> {
    Value(V),
    Lambda(L),
}
/// Result of one [`HigherOrderUDFImpl::lambda_parameters`] step.
pub enum LambdaParametersProgress {
    /// One entry per lambda argument, in argument order: `Some(fields)` for lambdas whose
    /// parameter fields are known now (their bodies get resolved to learn their output
    /// field before the next step), `None` for lambdas that must wait.
    Partial(Vec<Option<Vec<FieldRef>>>),
    /// Final parameter fields for every lambda argument, in argument order.
    Complete(Vec<Vec<FieldRef>>),
}
/// Implementation of a higher-order function: a function whose arguments may include
/// lambdas alongside ordinary values.
///
/// Equality and hashing of `dyn HigherOrderUDFImpl` are defined through
/// `DynEq::dyn_eq` and `DynHash::dyn_hash`.
pub trait HigherOrderUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
    /// Function name; also the prefix of the default `schema_name`.
    fn name(&self) -> &str;
    /// Alternative names for the function. Empty by default.
    fn aliases(&self) -> &[String] {
        &[]
    }
    /// Name used in schemas. Defaults to `name(arg1,arg2,...)` with the arguments rendered
    /// by `schema_name_from_exprs_comma_separated_without_space`.
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        Ok(format!(
            "{}({})",
            self.name(),
            schema_name_from_exprs_comma_separated_without_space(args)?
        ))
    }
    /// Accepted arguments, volatility and the lambda-parameter iteration limit.
    fn signature(&self) -> &HigherOrderSignature;
    /// Supplies the parameter fields of the lambda arguments.
    ///
    /// Called during lambda variable resolution with `step` = 0, 1, 2, ... . `fields` has
    /// one entry per call argument: value arguments carry their (coerced) field, lambda
    /// arguments carry `None` until their body's output field has been computed in an
    /// earlier step. Return [`LambdaParametersProgress::Partial`] to request another step,
    /// or [`LambdaParametersProgress::Complete`] with the final parameter fields.
    fn lambda_parameters(
        &self,
        step: usize,
        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
    ) -> Result<LambdaParametersProgress>;
    /// Optional hook to coerce value arguments given value and lambda types. Defaults to
    /// `Ok(None)` (presumably "no coercion" — TODO confirm in `type_coercion::functions`).
    fn coerce_values_for_lambdas(
        &self,
        _fields: &[ValueOrLambda<DataType, DataType>],
    ) -> Result<Option<Vec<DataType>>> {
        Ok(None)
    }
    /// Computes the function's return field from its argument fields.
    fn return_field_from_args(
        &self,
        args: HigherOrderReturnFieldArgs,
    ) -> Result<FieldRef>;
    /// Defaults to `true`. NOTE(review): exact semantics (nulls cleared before lambdas
    /// run?) are defined by the callers, not visible here — TODO document.
    fn clear_null_values(&self) -> bool {
        true
    }
    /// Evaluates the function.
    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue>;
    /// Whether the function may skip evaluating some arguments. Defaults to `false`; the
    /// default `conditional_arguments` is derived from it.
    fn short_circuits(&self) -> bool {
        false
    }
    /// Splits `args` into two groups; by default, when `short_circuits()` is true, the
    /// first group is empty and every argument is in the second, otherwise `None`.
    /// Presumably `(always evaluated, conditionally evaluated)` — TODO confirm with callers.
    fn conditional_arguments<'a>(
        &self,
        args: &'a [Expr],
    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
        if self.short_circuits() {
            Some((vec![], args.iter().collect()))
        } else {
            None
        }
    }
    /// Coerces value argument types (relevant to `UserDefined` signatures, presumably).
    /// Defaults to a "not implemented" error.
    fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!(
            "Function {} does not implement coerce_value_types",
            self.name()
        )
    }
    /// User-facing documentation, if any. `None` by default.
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
}
/// A higher-order function, wrapping a shared [`HigherOrderUDFImpl`].
#[derive(Debug, Clone)]
pub struct HigherOrderUDF {
    inner: Arc<dyn HigherOrderUDFImpl>,
}
/// Delegates to `PartialEq for dyn HigherOrderUDFImpl`, i.e. the implementation's `DynEq`.
impl PartialEq for HigherOrderUDF {
    fn eq(&self, other: &Self) -> bool {
        *self.inner == *other.inner
    }
}
/// Delegates to `PartialOrd for dyn HigherOrderUDFImpl` (name, then signature, then
/// aliases, with `None` for instances that tie on those but are not equal), so the
/// ordering rules live in one place instead of two copies that must be kept in sync.
impl PartialOrd for HigherOrderUDF {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        (*self.inner).partial_cmp(&*other.inner)
    }
}
impl Eq for HigherOrderUDF {}
/// Hashes through the implementation's `DynHash`, consistent with `PartialEq`.
impl Hash for HigherOrderUDF {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.dyn_hash(state)
    }
}
impl HigherOrderUDF {
pub fn new_from_impl<F>(fun: F) -> HigherOrderUDF
where
F: HigherOrderUDFImpl + 'static,
{
Self::new_from_shared_impl(Arc::new(fun))
}
pub fn new_from_shared_impl(fun: Arc<dyn HigherOrderUDFImpl>) -> HigherOrderUDF {
Self { inner: fun }
}
pub fn inner(&self) -> &Arc<dyn HigherOrderUDFImpl> {
&self.inner
}
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedHigherOrderUDFImpl::new(
Arc::clone(&self.inner),
aliases,
))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn aliases(&self) -> &[String] {
self.inner.aliases()
}
pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
self.inner.schema_name(args)
}
pub fn signature(&self) -> &HigherOrderSignature {
self.inner.signature()
}
pub fn lambda_parameters(
&self,
step: usize,
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
) -> Result<LambdaParametersProgress> {
self.inner.lambda_parameters(step, fields)
}
pub fn coerce_values_for_lambdas(
&self,
fields: &[ValueOrLambda<DataType, DataType>],
) -> Result<Option<Vec<DataType>>> {
self.inner.coerce_values_for_lambdas(fields)
}
pub fn return_field_from_args(
&self,
args: HigherOrderReturnFieldArgs,
) -> Result<FieldRef> {
self.inner.return_field_from_args(args)
}
pub fn clear_null_values(&self) -> bool {
self.inner.clear_null_values()
}
pub fn invoke_with_args(
&self,
args: HigherOrderFunctionArgs,
) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}
pub fn conditional_arguments<'a>(
&self,
args: &'a [Expr],
) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
self.inner.conditional_arguments(args)
}
pub fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_value_types(arg_types)
}
pub fn documentation(&self) -> Option<&Documentation> {
self.inner.documentation()
}
}
impl<F> From<F> for HigherOrderUDF
where
F: HigherOrderUDFImpl + 'static,
{
fn from(fun: F) -> Self {
Self::new_from_impl(fun)
}
}
/// Wraps an implementation to add extra aliases; every other call is forwarded.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedHigherOrderUDFImpl {
    inner: UdfEq<Arc<dyn HigherOrderUDFImpl>>,
    aliases: Vec<String>,
}
impl AliasedHigherOrderUDFImpl {
    /// Keeps `inner`'s existing aliases and appends `new_aliases` after them.
    fn new(
        inner: Arc<dyn HigherOrderUDFImpl>,
        new_aliases: impl IntoIterator<Item = &'static str>,
    ) -> Self {
        let aliases: Vec<String> = inner
            .aliases()
            .iter()
            .cloned()
            .chain(new_aliases.into_iter().map(String::from))
            .collect();
        Self {
            inner: inner.into(),
            aliases,
        }
    }
}
// `missing_trait_methods` flags any trait method not written out here, so a newly added
// provided method cannot silently fall back to its default instead of being forwarded.
#[warn(clippy::missing_trait_methods)]
impl HigherOrderUDFImpl for AliasedHigherOrderUDFImpl {
    // The only method not forwarded: the wrapper's own (extended) alias list.
    fn aliases(&self) -> &[String] {
        self.aliases.as_slice()
    }

    // Required methods, forwarded.
    fn name(&self) -> &str {
        self.inner.name()
    }
    fn signature(&self) -> &HigherOrderSignature {
        self.inner.signature()
    }
    fn lambda_parameters(
        &self,
        step: usize,
        arg_fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
    ) -> Result<LambdaParametersProgress> {
        self.inner.lambda_parameters(step, arg_fields)
    }
    fn return_field_from_args(
        &self,
        args: HigherOrderReturnFieldArgs,
    ) -> Result<FieldRef> {
        self.inner.return_field_from_args(args)
    }
    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
        self.inner.invoke_with_args(args)
    }

    // Provided methods, forwarded so the wrapped implementation's overrides still apply.
    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        self.inner.schema_name(args)
    }
    fn coerce_values_for_lambdas(
        &self,
        arg_fields: &[ValueOrLambda<DataType, DataType>],
    ) -> Result<Option<Vec<DataType>>> {
        self.inner.coerce_values_for_lambdas(arg_fields)
    }
    fn clear_null_values(&self) -> bool {
        self.inner.clear_null_values()
    }
    fn short_circuits(&self) -> bool {
        self.inner.short_circuits()
    }
    fn conditional_arguments<'a>(
        &self,
        args: &'a [Expr],
    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
        self.inner.conditional_arguments(args)
    }
    fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        self.inner.coerce_value_types(arg_types)
    }
    fn documentation(&self) -> Option<&Documentation> {
        self.inner.documentation()
    }
}
/// Sets the field of every `Expr::LambdaVariable` in `expr` from `vars`, and resolves the
/// lambdas of every higher-order function call via `resolve_higher_order_function`.
///
/// `vars` maps a variable name to a stack of bindings; the last entry is the innermost
/// binding and is the one used. A variable is reported as transformed when its field was
/// unset or differs from the resolved one.
///
/// # Errors
/// Plan error when a variable has no binding in `vars`; internal error when a binding
/// stack is empty.
pub(crate) fn resolve_lambda_variables(
    expr: Expr,
    schema: &DFSchema,
    vars: &mut HashMap<String, Vec<FieldRef>>,
) -> Result<Transformed<Expr>> {
    expr.transform_down(|node| match node {
        Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
            resolve_higher_order_function(func, args, schema, vars)
        }
        Expr::LambdaVariable(mut variable) => {
            let innermost = vars
                .get(&variable.name)
                .ok_or_else(|| {
                    plan_datafusion_err!(
                        "missing field of lambda variable {} while resolving",
                        variable.name
                    )
                })?
                .last()
                .ok_or_else(|| {
                    internal_datafusion_err!("every entry should have at least one field")
                })?;
            let resolved = Arc::clone(innermost).renamed(&variable.name);
            let changed = variable.field.as_ref() != Some(&resolved);
            variable.field = Some(resolved);
            Ok(Transformed::new_transformed(
                Expr::LambdaVariable(variable),
                changed,
            ))
        }
        other => Ok(Transformed::no(other)),
    })
}
/// Resolves the lambda variables of a single higher-order function call.
///
/// Value arguments are resolved first when `vars` is non-empty (the call is nested in
/// another lambda), since they may reference variables of enclosing lambdas. Then:
///
/// 1. `lambda_parameters` is called with `step` = 0, 1, ... while it returns
///    [`LambdaParametersProgress::Partial`]. For each lambda whose parameters are known,
///    a *copy* of its body is resolved only to learn its output field, which is fed back
///    to the function on the next step.
/// 2. On [`LambdaParametersProgress::Complete`], every lambda body is resolved with its
///    final parameter fields. Because phase 1 never modifies the bodies, the reported
///    `transformed` flag is relative to the input expression.
///
/// Returns [`TreeNodeRecursion::Jump`] since all children were visited here.
///
/// # Errors
/// Plan errors when `lambda_parameters` reports the wrong number of lambdas, a lambda
/// declares more parameters than supported or repeats a name, or no `Complete` result is
/// produced within `lambda_parameters_max_iterations` calls. Internal errors when the
/// argument kinds from type coercion do not line up with the call's arguments.
fn resolve_higher_order_function(
    func: Arc<HigherOrderUDF>,
    args: Vec<Expr>,
    schema: &DFSchema,
    vars: &mut HashMap<String, Vec<FieldRef>>,
) -> Result<Transformed<Expr>> {
    let args = if !vars.is_empty() {
        args.map_elements(|arg| match arg {
            Expr::Lambda(_) => Ok(Transformed::no(arg)),
            _ => resolve_lambda_variables(arg, schema, vars),
        })?
    } else {
        Transformed::no(args)
    };
    let transformed = args.transformed;
    let mut args = args.data;

    // Value arguments are typed up front; lambda outputs are unknown until resolved.
    let current_fields = args
        .iter()
        .map(|e| match e {
            Expr::Lambda(_) => Ok(ValueOrLambda::Lambda(None)),
            _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)),
        })
        .collect::<Result<Vec<_>>>()?;
    let mut fields = value_fields_with_higher_order_udf(&current_fields, func.as_ref())?;
    // `args` and `fields` are paired positionally below; `zip` would silently drop
    // trailing entries on a length mismatch.
    if fields.len() != args.len() {
        return internal_err!(
            "value_fields_with_higher_order_udf returned {} fields for {} arguments",
            fields.len(),
            args.len()
        );
    }
    let num_lambdas = args.iter().filter(|a| matches!(a, Expr::Lambda(_))).count();

    // Phase 1: iterate until the function knows every lambda's parameters.
    let limit = func.signature().lambda_parameters_max_iterations;
    let mut step = 0;
    let lambda_params = loop {
        let partial = match func.lambda_parameters(step, &fields)? {
            LambdaParametersProgress::Complete(params) => break params,
            LambdaParametersProgress::Partial(params) => params,
        };
        if partial.len() != num_lambdas {
            return plan_err!(
                "{} lambda_parameters returned {} lambdas but {num_lambdas} expected",
                func.name(),
                partial.len()
            );
        }
        let mut partial = partial.into_iter();
        for (arg, field) in std::iter::zip(&mut args, &mut fields) {
            match (arg, field) {
                (Expr::Lambda(lambda), ValueOrLambda::Lambda(field)) => {
                    let params = partial.next().ok_or_else(|| {
                        internal_datafusion_err!(
                            "params len should have been checked above"
                        )
                    })?;
                    if let Some(params) = params {
                        // Only the output field is needed to drive the next step; the
                        // body itself is resolved for real in phase 2.
                        let resolved = resolve_lambda_body(
                            func.name(),
                            &lambda.params,
                            params,
                            (*lambda.body).clone(),
                            schema,
                            vars,
                        )?;
                        *field = Some(resolved.data.to_field(schema)?.1);
                    }
                }
                (Expr::Lambda(_), ValueOrLambda::Value(_)) => {
                    return internal_err!(
                        "value_fields_with_higher_order_udf returned a value for a lambda argument"
                    );
                }
                (_, ValueOrLambda::Lambda(_)) => {
                    return internal_err!(
                        "value_fields_with_higher_order_udf returned a lambda for a value argument"
                    );
                }
                (_, ValueOrLambda::Value(_)) => {}
            }
        }
        step += 1;
        // `step` now counts completed calls; stop after exactly `limit` of them.
        if step >= limit {
            return plan_err!(
                "{} lambda_parameters called {limit} times without completion",
                func.name()
            );
        }
    };

    // Phase 2: resolve every lambda body with its final parameter fields.
    if lambda_params.len() != num_lambdas {
        return plan_err!(
            "{} lambda_parameters returned {} values for {num_lambdas} lambdas",
            func.name(),
            lambda_params.len()
        );
    }
    let mut lambda_params = lambda_params.into_iter();
    let args = args.map_elements(|arg| match arg {
        Expr::Lambda(mut lambda) => {
            let params = lambda_params.next().ok_or_else(|| {
                internal_datafusion_err!(
                    "lambda_params len should have been checked above"
                )
            })?;
            let body = resolve_lambda_body(
                func.name(),
                &lambda.params,
                params,
                mem::take(lambda.body.as_mut()),
                schema,
                vars,
            )?;
            *lambda.body = body.data;
            // The body is already resolved; don't descend into the lambda again.
            Ok(Transformed::new(
                Expr::Lambda(lambda),
                body.transformed,
                TreeNodeRecursion::Jump,
            ))
        }
        arg => Ok(Transformed::no(arg)),
    })?;
    Ok(Transformed::new(
        Expr::HigherOrderFunction(HigherOrderFunction::new(func, args.data)),
        transformed || args.transformed,
        TreeNodeRecursion::Jump,
    ))
}

/// Binds a lambda's `params` to `fields` in `vars`, resolves the variables in `body`, then
/// removes exactly those bindings again (also when resolution fails).
///
/// # Errors
/// Plan error if the lambda declares more parameters than `fields` provides or repeats a
/// parameter name; otherwise propagates errors from resolving `body`.
fn resolve_lambda_body(
    func_name: &str,
    params: &[String],
    fields: Vec<FieldRef>,
    body: Expr,
    schema: &DFSchema,
    vars: &mut HashMap<String, Vec<FieldRef>>,
) -> Result<Transformed<Expr>> {
    // Validated before binding so that every name popped by `remove_scope` below was
    // pushed here; surplus params would otherwise pop bindings of an enclosing lambda.
    if params.len() > fields.len() {
        return plan_err!(
            "{} lambda defined {} params ({}), but only {} supported",
            func_name,
            params.len(),
            display_comma_separated(params),
            fields.len()
        );
    }
    if !all_unique(params) {
        return plan_err!(
            "lambda params must be unique, got ({})",
            params.join(", ")
        );
    }
    for (name, field) in std::iter::zip(params, fields) {
        vars.entry_ref(name)
            .or_default()
            .push(field.renamed(name.as_str()));
    }
    let resolved = resolve_lambda_variables(body, schema, vars);
    remove_scope(vars, params)?;
    resolved
}
/// Drops the innermost binding of every name in `scope`, removing the map entry entirely
/// when it held the only binding.
///
/// # Errors
/// Internal error when a name has no entry, or (defensively) an entry with no bindings.
fn remove_scope(
    vars: &mut HashMap<String, Vec<FieldRef>>,
    scope: &[String],
) -> Result<()> {
    for name in scope {
        let EntryRef::Occupied(mut entry) = vars.entry_ref(name) else {
            return internal_err!("no empty value should be in the map");
        };
        if entry.get().len() == 1 {
            entry.remove();
            continue;
        }
        entry.get_mut().pop().ok_or_else(|| {
            internal_datafusion_err!("every entry should have at least one field")
        })?;
    }
    Ok(())
}
/// Returns `true` when no name occurs twice in `params`.
///
/// Short inputs are checked directly; longer ones use a set of borrowed names.
fn all_unique(params: &[String]) -> bool {
    match params {
        [] | [_] => true,
        [first, second] => first != second,
        _ => {
            let mut seen = HashSet::with_capacity(params.len());
            !params.iter().any(|name| !seen.insert(name.as_str()))
        }
    }
}
// Tests for equality / hashing / ordering of higher-order UDFs and for lambda variable
// resolution.
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::DefaultHasher;
    use std::sync::Arc;
    use arrow_schema::{DataType, Field, FieldRef, Schema};
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_expr_common::signature::Volatility;
    use crate::{
        Expr, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl,
        LambdaParametersProgress, ValueOrLambda, col,
        expr::{HigherOrderFunction, LambdaVariable},
        lambda, lambda_var, lit,
    };

    // Minimal implementation whose identity (derived `PartialEq`/`Hash`) is its name,
    // `field` and signature. Only the metadata methods are exercised by the tests.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestHigherOrderUDF {
        name: &'static str,
        // Not read by any trait method; it only makes instances with the same name and
        // signature compare unequal.
        field: &'static str,
        signature: HigherOrderSignature,
    }
    impl HigherOrderUDFImpl for TestHigherOrderUDF {
        fn name(&self) -> &str {
            self.name
        }
        fn signature(&self) -> &HigherOrderSignature {
            &self.signature
        }
        fn lambda_parameters(
            &self,
            _step: usize,
            _fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
        ) -> Result<LambdaParametersProgress> {
            unimplemented!()
        }
        fn return_field_from_args(
            &self,
            _args: HigherOrderReturnFieldArgs,
        ) -> Result<FieldRef> {
            unimplemented!()
        }
        fn invoke_with_args(
            &self,
            _args: HigherOrderFunctionArgs,
        ) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // Identical definitions: equal, same hash, and ordered as Equal.
        let f = test_func("foo", "a");
        let f2 = test_func("foo", "a");
        assert_eq!(&f, &f2);
        assert_eq!(hash(&f), hash(&f2));
        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
        // Same name/signature/aliases but different `field`: not equal, so ordering must
        // be `None` rather than `Some(Equal)`. The hash check assumes no collision.
        let b = test_func("foo", "b");
        assert_ne!(&f, &b);
        assert_ne!(hash(&f), hash(&b)); assert_eq!(f.partial_cmp(&b), None);
        // Different names: ordered by name ("foo" < "other").
        let o = test_func("other", "a");
        assert_ne!(&f, &o);
        assert_ne!(hash(&f), hash(&o)); assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));
        assert_ne!(&b, &o);
        assert_ne!(hash(&b), hash(&o)); assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
    }

    // Builds a `TestHigherOrderUDF` with a `VariadicAny` immutable signature.
    fn test_func(name: &'static str, parameter: &'static str) -> Arc<HigherOrderUDF> {
        Arc::new(HigherOrderUDF::new_from_impl(TestHigherOrderUDF {
            name,
            field: parameter,
            signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
        }))
    }

    // Hashes `value` with a fresh `DefaultHasher`.
    fn hash<T: Hash>(value: &T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }

    // Mock of `array_reduce(list, initial_value, merge, finish)`.
    //
    // `lambda_parameters`: while the merge lambda's output field is unknown (step 0), the
    // merge lambda gets `(initial_value, list element)` and `finish` waits. Once it is
    // known, resolution completes with merge `(accumulator, list element)` and finish
    // `(accumulator)`, where the accumulator is the merge lambda's output field.
    // The return field is the finish lambda's field.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockArrayReduce {
        signature: HigherOrderSignature,
    }
    impl HigherOrderUDFImpl for MockArrayReduce {
        fn name(&self) -> &str {
            "array_reduce"
        }
        fn aliases(&self) -> &[String] {
            &[]
        }
        fn signature(&self) -> &HigherOrderSignature {
            &self.signature
        }
        fn lambda_parameters(
            &self,
            step: usize,
            fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
        ) -> Result<LambdaParametersProgress> {
            let [
                ValueOrLambda::Value(list),
                ValueOrLambda::Value(initial_value),
                ValueOrLambda::Lambda(merge),
                ValueOrLambda::Lambda(_finish),
            ] = fields
            else {
                unreachable!()
            };
            let list_field = match list.data_type() {
                DataType::List(field) => field,
                _ => unreachable!(),
            };
            Ok(match (step, merge) {
                (0, None) => {
                    LambdaParametersProgress::Partial(vec![
                        Some(vec![Arc::clone(initial_value), Arc::clone(list_field)]),
                        None,
                    ])
                }
                (1, Some(accumulator)) | (0, Some(accumulator)) => {
                    LambdaParametersProgress::Complete(vec![
                        vec![Arc::clone(accumulator), Arc::clone(list_field)],
                        vec![Arc::clone(accumulator)],
                    ])
                }
                (1, None) => {
                    unreachable!()
                }
                _ => unreachable!(),
            })
        }
        fn return_field_from_args(
            &self,
            args: HigherOrderReturnFieldArgs,
        ) -> Result<FieldRef> {
            let [
                ValueOrLambda::Value(_list),
                ValueOrLambda::Value(_initial_value),
                ValueOrLambda::Lambda(_merge),
                ValueOrLambda::Lambda(finish),
            ] = args.arg_fields
            else {
                unreachable!()
            };
            Ok(Arc::clone(finish))
        }
        fn invoke_with_args(
            &self,
            _args: HigherOrderFunctionArgs,
        ) -> Result<ColumnarValue> {
            unreachable!()
        }
    }

    // Nested reduce calls over `c: List(List(Int32))`. The inner call's list argument is
    // the outer lambda's `v`, its merge body references the outer `acc1`, and its own `v`
    // shadows the outer one. Expected fields: the outer `v` is `List(Int32)` and the inner
    // `v` is `Int32` (list elements); accumulators and `reduced` become Float64, since the
    // inner finish lambda multiplies by 2.0 (and, presumably, Int + Float64 coerces to
    // Float64).
    #[test]
    fn test_resolve_lambda_variables() {
        let schema = DFSchema::try_from(Schema::new(vec![Field::new(
            "c",
            DataType::new_list(DataType::new_list(DataType::Int32, true), true),
            true,
        )]))
        .unwrap();
        let func = Arc::new(HigherOrderUDF::new_from_impl(MockArrayReduce {
            signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
        }));
        let expr = Expr::HigherOrderFunction(HigherOrderFunction::new(
            Arc::clone(&func),
            vec![
                col("c"),
                lit(0),
                lambda(
                    ["acc1", "v"],
                    lambda_var("acc1")
                        + Expr::HigherOrderFunction(HigherOrderFunction::new(
                            Arc::clone(&func),
                            vec![
                                lambda_var("v"),
                                lit(0),
                                lambda(
                                    ["acc2", "v"],
                                    lambda_var("acc2")
                                        + lambda_var("acc1")
                                        + lambda_var("v"),
                                ),
                                lambda(["reduced"], lambda_var("reduced") * lit(2.0)),
                            ],
                        )),
                ),
                lambda(["reduced"], lambda_var("reduced") * lit(2)),
            ],
        ));
        let resolved_expr = expr.resolve_lambda_variables(&schema).unwrap().data;
        let expected = Expr::HigherOrderFunction(HigherOrderFunction::new(
            Arc::clone(&func),
            vec![
                col("c"),
                lit(0),
                lambda(
                    ["acc1", "v"],
                    resolved_lambda_var("acc1", DataType::Float64, true)
                        + Expr::HigherOrderFunction(HigherOrderFunction::new(
                            Arc::clone(&func),
                            vec![
                                resolved_lambda_var(
                                    "v",
                                    DataType::new_list(DataType::Int32, true),
                                    true,
                                ),
                                lit(0),
                                lambda(
                                    ["acc2", "v"],
                                    resolved_lambda_var("acc2", DataType::Float64, true)
                                        + resolved_lambda_var(
                                            "acc1",
                                            DataType::Float64,
                                            true,
                                        )
                                        + resolved_lambda_var("v", DataType::Int32, true),
                                ),
                                lambda(
                                    ["reduced"],
                                    resolved_lambda_var(
                                        "reduced",
                                        DataType::Float64,
                                        true,
                                    ) * lit(2.0),
                                ),
                            ],
                        )),
                ),
                lambda(
                    ["reduced"],
                    resolved_lambda_var("reduced", DataType::Float64, true) * lit(2),
                ),
            ],
        ));
        assert_eq!(resolved_expr, expected);
    }

    // A lambda variable whose field is already set to `name: dt` with the given nullability.
    fn resolved_lambda_var(name: &str, dt: DataType, nullable: bool) -> Expr {
        Expr::LambdaVariable(LambdaVariable::new(
            name.into(),
            Some(Arc::new(Field::new(name, dt, nullable))),
        ))
    }
}