Skip to main content

datafusion_expr/
higher_order_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`HigherOrderUDF`]: User Defined Higher Order Functions
19
20use crate::expr::{
21    HigherOrderFunction, display_comma_separated,
22    schema_name_from_exprs_comma_separated_without_space,
23};
24use crate::type_coercion::functions::value_fields_with_higher_order_udf;
25use crate::udf_eq::UdfEq;
26use crate::{ColumnarValue, Documentation, Expr, ExprSchemable};
27use arrow::array::{ArrayRef, RecordBatch};
28use arrow::datatypes::{DataType, FieldRef, Schema};
29use arrow_schema::SchemaRef;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::datatype::FieldExt;
32use datafusion_common::hash_map::EntryRef;
33use datafusion_common::tree_node::{
34    Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
35};
36use datafusion_common::{
37    DFSchema, HashMap, HashSet, Result, ScalarValue, exec_err, internal_datafusion_err,
38    internal_err, not_impl_err, plan_datafusion_err, plan_err,
39};
40use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
41use datafusion_expr_common::signature::Volatility;
42use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
43use std::any::Any;
44use std::cmp::Ordering;
45use std::fmt::Debug;
46use std::hash::{Hash, Hasher};
47use std::mem;
48use std::sync::Arc;
49
/// The types of arguments for which a function has implementations.
///
/// [`HigherOrderTypeSignature`] **DOES NOT** define the types that a user query could call the
/// function with. DataFusion will automatically coerce (cast) argument types to
/// one of the supported function signatures, if possible.
///
/// # Overview
/// Functions typically provide implementations for a small number of different
/// argument [`DataType`]s, rather than all possible combinations. If a user
/// calls a function with arguments that do not match any of the declared types,
/// DataFusion will attempt to automatically coerce (add casts to) function
/// arguments so they match the [`HigherOrderTypeSignature`]. See the [`type_coercion`] module
/// for more details.
///
/// Note: the derived [`PartialOrd`] orders by variant declaration order first.
///
/// [`type_coercion`]: crate::type_coercion
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum HigherOrderTypeSignature {
    /// The acceptable signature and coercion rules are special for this
    /// function.
    ///
    /// If this signature is specified,
    /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare argument types.
    UserDefined,
    /// One or more lambdas or arguments with arbitrary types.
    VariadicAny,
    /// The specified number of lambdas or arguments with arbitrary types.
    Any(usize),
    /// Exactly the specified arguments in the given order, with arbitrary types.
    /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value
    /// argument types.
    ///
    /// Each entry only marks whether that position holds a value ([`ValueOrLambda::Value`])
    /// or a lambda ([`ValueOrLambda::Lambda`]). For example,
    /// `Exact(vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())])` describes one value
    /// argument followed by one lambda.
    Exact(Vec<ValueOrLambda<(), ()>>),
}
82
/// Provides information necessary for calling a higher order function.
///
/// - [`HigherOrderTypeSignature`] defines the argument types that a function has implementations
///   for.
///
/// - [`Volatility`] defines how the output of the function changes with the input.
///
/// - `lambda_parameters_max_iterations` bounds how many times
///   [`HigherOrderUDFImpl::lambda_parameters`] may be called while resolving lambda parameters.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub struct HigherOrderSignature {
    /// The data types that the function accepts. See [HigherOrderTypeSignature] for more information.
    pub type_signature: HigherOrderTypeSignature,
    /// The volatility of the function. See [Volatility] for more information.
    pub volatility: Volatility,
    /// The max number of times to call [HigherOrderUDFImpl::lambda_parameters] before raising an error.
    /// Used to guard against implementations that cause an infinite loop by endlessly returning
    /// [LambdaParametersProgress::Partial]. Defaults to 256 when the signature is built with any of
    /// the `HigherOrderSignature` constructors.
    pub lambda_parameters_max_iterations: usize,
}
100
101const LAMBDA_PARAMETERS_MAX_ITERATIONS: usize = 256;
102
103impl HigherOrderSignature {
104    /// Creates a new `HigherOrderSignature` from a given type signature and volatility.
105    pub fn new(type_signature: HigherOrderTypeSignature, volatility: Volatility) -> Self {
106        HigherOrderSignature {
107            type_signature,
108            volatility,
109            lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
110        }
111    }
112
113    /// User-defined coercion rules for the function.
114    pub fn user_defined(volatility: Volatility) -> Self {
115        Self {
116            type_signature: HigherOrderTypeSignature::UserDefined,
117            volatility,
118            lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
119        }
120    }
121
122    /// An arbitrary number of lambdas or arguments of any type.
123    pub fn variadic_any(volatility: Volatility) -> Self {
124        Self {
125            type_signature: HigherOrderTypeSignature::VariadicAny,
126            volatility,
127            lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
128        }
129    }
130
131    /// A specified number of arguments of any type
132    pub fn any(arg_count: usize, volatility: Volatility) -> Self {
133        Self {
134            type_signature: HigherOrderTypeSignature::Any(arg_count),
135            volatility,
136            lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
137        }
138    }
139
140    /// Exactly the specified arguments in the given order, with arbitrary types.
141    /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value
142    /// argument types.
143    ///
144    /// # Example
145    /// A function that takes one value argument followed by one lambda:
146    /// ```
147    /// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility};
148    /// let sig = HigherOrderSignature::exact(
149    ///     vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
150    ///     Volatility::Immutable,
151    /// );
152    /// ```
153    pub fn exact(args: Vec<ValueOrLambda<(), ()>>, volatility: Volatility) -> Self {
154        Self {
155            type_signature: HigherOrderTypeSignature::Exact(args),
156            volatility,
157            lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
158        }
159    }
160}
161
/// Two higher order function trait objects are equal when [`DynEq::dyn_eq`] reports them
/// equal, which lets functions be compared without knowing their concrete type.
impl PartialEq for dyn HigherOrderUDFImpl {
    fn eq(&self, other: &Self) -> bool {
        // Coerce `other` to the argument type `dyn_eq` expects and let the
        // implementation decide equality.
        self.dyn_eq(other as _)
    }
}
167
168impl PartialOrd for dyn HigherOrderUDFImpl {
169    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
170        let mut cmp = self.name().cmp(other.name());
171        if cmp == Ordering::Equal {
172            cmp = self.signature().partial_cmp(other.signature())?;
173        }
174        if cmp == Ordering::Equal {
175            cmp = self.aliases().partial_cmp(other.aliases())?;
176        }
177        // Contract for PartialOrd and PartialEq consistency requires that
178        // a == b if and only if partial_cmp(a, b) == Some(Equal).
179        if cmp == Ordering::Equal && self != other {
180            // Functions may have other properties besides name and signature
181            // that differentiate two instances (e.g. type, or arbitrary parameters).
182            // We cannot return Some(Equal) in such case.
183            return None;
184        }
185        debug_assert!(
186            cmp == Ordering::Equal || self != other,
187            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
188            The functions compare as equal, but they are not equal based on general properties that \
189            the PartialOrd implementation observes,",
190            self.name(),
191            other.name()
192        );
193        Some(cmp)
194    }
195}
196
// Marker impl: equality of `dyn HigherOrderUDFImpl` is decided by `DynEq::dyn_eq`
// (see the `PartialEq` impl), which implementations are expected to make an
// equivalence relation.
impl Eq for dyn HigherOrderUDFImpl {}

/// Hashes a higher order function trait object by delegating to [`DynHash::dyn_hash`].
///
/// Implementations should keep `dyn_hash` consistent with `dyn_eq` (equal functions
/// must hash equally).
impl Hash for dyn HigherOrderUDFImpl {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.dyn_hash(state)
    }
}
204
/// Arguments passed to [`HigherOrderUDFImpl::invoke_with_args`] when invoking a
/// higher order function.
#[derive(Debug, Clone)]
pub struct HigherOrderFunctionArgs {
    /// The evaluated value arguments and lambdas passed to the function
    pub args: Vec<ValueOrLambda<ColumnarValue, LambdaArgument>>,
    /// The field associated with each argument.
    /// For lambdas, it is the field of the lambda's result when evaluated
    /// with the parameters returned from [`HigherOrderUDFImpl::lambda_parameters`]
    pub arg_fields: Vec<ValueOrLambda<FieldRef, FieldRef>>,
    /// The number of rows in the record batch being evaluated
    pub number_rows: usize,
    /// The return field of the higher order function, as returned
    /// by `return_field_from_args` when the physical expression was
    /// created from the logical expression
    pub return_field: FieldRef,
    /// The config options at execution time
    pub config_options: Arc<ConfigOptions>,
}
225
226impl HigherOrderFunctionArgs {
227    /// The return type of the function. See [`Self::return_field`] for more
228    /// details.
229    pub fn return_type(&self) -> &DataType {
230        self.return_field.data_type()
231    }
232}
233
/// A lambda argument to a HigherOrderFunction, evaluated through
/// [`LambdaArgument::evaluate`]
#[derive(Clone, Debug)]
pub struct LambdaArgument {
    /// The parameters defined in this lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be `vec![Arc::new(Field::new("v", DataType::Int32, true))]`
    params: Vec<FieldRef>,
    /// The body of the lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be the physical expression of `-v`
    body: Arc<dyn PhysicalExpr>,
    /// Cached schema the body is evaluated against: the fields of `captures`
    /// (if any) followed by `params`. Reused across every `evaluate` call
    /// (and across every nested-list iteration when the lambda is called once
    /// per outer sublist), avoiding the per-call `Schema::new` build that
    /// includes constructing the internal name -> index map.
    schema: SchemaRef,
    /// A RecordBatch containing the captured columns inside this lambda body, if any
    ///
    /// For example, for `array_transform([2], v -> v + a + b)`,
    /// this will be a `RecordBatch` with two columns, `a` and `b`
    captures: Option<RecordBatch>,
}
258
259impl LambdaArgument {
260    pub fn new(
261        params: Vec<FieldRef>,
262        body: Arc<dyn PhysicalExpr>,
263        captures: Option<RecordBatch>,
264    ) -> Self {
265        let fields = match &captures {
266            Some(batch) => batch
267                .schema_ref()
268                .fields()
269                .iter()
270                .cloned()
271                .chain(params.clone())
272                .collect(),
273            None => params.clone(),
274        };
275
276        let schema = Arc::new(Schema::new(fields));
277
278        Self {
279            params,
280            body,
281            schema,
282            captures,
283        }
284    }
285
286    /// Evaluate this lambda
287    /// `args` should evaluate to the value of each parameter
288    /// of the correspondent lambda returned in [HigherOrderUDFImpl::lambda_parameters].
289    ///
290    /// `spread_captures` is responsible for transforming the captured column arrays
291    /// so they align with the evaluation batch. Captures are snapshotted from the
292    /// outer batch at construction time, giving one value per outer row, but the
293    /// function may evaluate the lambda body over a batch with a different number
294    /// of rows. It is the function's responsibility to provide the appropriate
295    /// `spread_captures` closure to expand (or otherwise reshape) the captures
296    /// to match.
297    ///
298    /// Taking as an example the following table:
299    ///
300    /// ```sql
301    /// CREATE TABLE t (arr INT[], a INT) AS VALUES
302    ///   ([1, 2, 3], 10),
303    ///   ([],        20),
304    ///   ([4],       30);
305    /// ```
306    ///
307    /// `SELECT array_transform(arr, v -> v + a) from t` would execute over three outer rows:
308    ///
309    /// ```text
310    /// arr (ListArray):  [[1, 2, 3], [], [4]]   -- 3 outer rows, 4 total elements
311    /// a   (captured):   [10,        20,  30]   -- one value per outer row
312    /// ```
313    ///
314    /// `array_transform` flattens the list elements into a single batch of 4 rows,
315    /// so `spread_captures` must repeat/drop captured values to match:
316    ///
317    /// ```text
318    /// v (flattened args): [1,  2,  3,  4]
319    /// a (spread):         [10, 10, 10, 30]  -- 10 repeated for 3 elements in row 0,
320    ///                                        -- 20 dropped for the empty sublist in row 1,
321    ///                                        -- 30 once for the single element in row 2
322    /// ```
323    ///
324    /// The lambda body `v + a` then evaluates element-wise over these 4-row arrays,
325    /// producing `[11, 12, 13, 34]`, which `array_transform` reassembles into `[[11, 12, 13], [], [34]]`.
326    ///
327    /// If the lambda has no captures, `spread_captures` is never called.
328    pub fn evaluate(
329        &self,
330        args: &[&dyn Fn() -> Result<ArrayRef>],
331        spread_captures: impl FnOnce(&[ArrayRef]) -> Result<Vec<ArrayRef>>,
332    ) -> Result<ColumnarValue> {
333        let spread_captures = self
334            .captures
335            .as_ref()
336            .map(|captures| {
337                let spread_columns = spread_captures(captures.columns())?;
338
339                RecordBatch::try_new(captures.schema(), spread_columns)
340            })
341            .transpose()?;
342
343        let merged = merge_captures_with_variables(
344            spread_captures.as_ref(),
345            Arc::clone(&self.schema),
346            &self.params,
347            args,
348        )?;
349
350        self.body.evaluate(&merged)
351    }
352}
353
354fn merge_captures_with_variables(
355    captures: Option<&RecordBatch>,
356    schema: SchemaRef,
357    params: &[FieldRef],
358    variables: &[&dyn Fn() -> Result<ArrayRef>],
359) -> Result<RecordBatch> {
360    if variables.len() < params.len() {
361        return exec_err!(
362            "expected at least {} lambda arguments to merge with captures, got {}",
363            params.len(),
364            variables.len()
365        );
366    }
367
368    let columns = match captures {
369        Some(captures) => {
370            let mut columns = captures.columns().to_vec();
371
372            for arg in &variables[..params.len()] {
373                columns.push(arg()?);
374            }
375
376            columns
377        }
378        None => variables
379            .iter()
380            .take(params.len())
381            .map(|arg| arg())
382            .collect::<Result<_>>()?,
383    };
384
385    Ok(RecordBatch::try_new(schema, columns)?)
386}
387
/// Information about arguments passed to the function
///
/// This structure contains metadata about how the function was called
/// such as the type of the arguments, any scalar arguments and if the
/// arguments can (ever) be null
///
/// See [`HigherOrderUDFImpl::return_field_from_args`] for more information
#[derive(Clone, Debug)]
pub struct HigherOrderReturnFieldArgs<'a> {
    /// The fields (data types, nullability, ...) of the arguments to the function
    ///
    /// If argument `i` to the function is a lambda, it will be the field of the result of the
    /// lambda if evaluated with the parameters returned from [`HigherOrderUDFImpl::lambda_parameters`]
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be
    /// ```ignore
    /// [
    ///     ValueOrLambda::Value(Field::new("", DataType::new_list(DataType::Int32, true), true)),
    ///     ValueOrLambda::Lambda(Field::new("", DataType::Boolean, true))
    /// ]
    /// ```
    pub arg_fields: &'a [ValueOrLambda<FieldRef, FieldRef>],
    /// Is argument `i` to the function a scalar (constant)?
    ///
    /// If the argument `i` is not a scalar, it will be None
    ///
    /// For example, if a function is called like `array_transform([1], v -> v == 5)`
    /// this field will be `[Some(ScalarValue::List(...)), None]`
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
}
419
/// An argument to a higher order function: either a value or a lambda.
///
/// The type parameters carry the data attached to each kind of argument. For example,
/// `ValueOrLambda<ColumnarValue, LambdaArgument>` holds evaluated arguments, and
/// `ValueOrLambda<FieldRef, FieldRef>` holds argument fields.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
pub enum ValueOrLambda<V, L> {
    /// A value argument with associated data of type `V`
    Value(V),
    /// A lambda argument with associated data of type `L`
    Lambda(L),
}
428
429/// Represents a step during the resolution of the parameters of all lambdas of a given
430/// higher-order function via [HigherOrderUDFImpl::lambda_parameters]. It's valid that the
431/// fields of a given lambda changes between steps, and is up to the implementation to
432/// provide during the function evaluation the parameters that matches the fields returned
433/// at the [LambdaParametersProgress::Complete] step. See [HigherOrderUDFImpl::lambda_parameters]
434/// docs for more details
435pub enum LambdaParametersProgress {
436    /// The parameters of some lambdas are unknown due to a dependency on another lambda output field
437    /// or are placeholders due to a dependency on it's own output field. It's perfectly valid to
438    /// contain only `Some`'s and not a single `None`, representing lambdas that depends only on itself
439    /// and not on others. [HigherOrderUDFImpl::lambda_parameters] will be called again with the output
440    /// field of all lambdas with known parameters.
441    Partial(Vec<Option<Vec<FieldRef>>>),
442    /// There are no unmet dependencies and all parameters are known, [HigherOrderUDFImpl::lambda_parameters]
443    /// will not be called again
444    Complete(Vec<Vec<FieldRef>>),
445}
446
447/// Trait for implementing user defined higher order functions.
448///
449/// This trait exposes the full API for implementing user defined functions and
450/// can be used to implement any function.
451///
452/// New higher order functions typically implement this trait and are then
453/// wrapped in a [`HigherOrderUDF`] for registration with DataFusion.
454///
455/// See [`array_transform.rs`] for a commented complete implementation
456///
457/// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs
458pub trait HigherOrderUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
    /// Returns this function's name.
    ///
    /// Alternate names for the same function are reported by [`Self::aliases`].
    fn name(&self) -> &str;
461
    /// Returns any aliases (alternate names) for this function.
    ///
    /// Aliases can be used to invoke the same function using different names.
    /// For example in some databases `now()` and `current_timestamp()` are
    /// aliases for the same function. This behavior can be obtained by
    /// returning `current_timestamp` as an alias for the `now` function.
    ///
    /// Note: `aliases` should only include names other than [`Self::name`].
    /// Defaults to `[]` (no aliases).
    fn aliases(&self) -> &[String] {
        &[]
    }
474
475    /// Returns the name of the column this expression would create
476    ///
477    /// See [`Expr::schema_name`] for details
478    fn schema_name(&self, args: &[Expr]) -> Result<String> {
479        Ok(format!(
480            "{}({})",
481            self.name(),
482            schema_name_from_exprs_comma_separated_without_space(args)?
483        ))
484    }
485
    /// Returns a [`HigherOrderSignature`] describing the argument types for which this
    /// function has an implementation, and the function's [`Volatility`].
    ///
    /// The signature also carries `lambda_parameters_max_iterations`, which caps how many
    /// times [`Self::lambda_parameters`] is called.
    ///
    /// See [`HigherOrderSignature`] for more details on argument type handling
    /// and [`Self::return_field_from_args`] for computing the return type.
    ///
    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
    fn signature(&self) -> &HigherOrderSignature;
494
    /// Return the fields of all the parameters supported by the lambdas in `fields`.
    /// If a lambda supports multiple parameters, all of them should be returned, regardless of
    /// whether they are used or not in a particular invocation.
    ///
    /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call the helper
    /// [`HigherOrderFunction::lambda_parameters`] instead of this method directly
    ///
    /// The names of the returned fields are ignored.
    ///
    /// This function is repeatedly called until [LambdaParametersProgress::Complete] is returned,
    /// with `step` increased by one at each invocation, starting at 0. The number of calls is
    /// capped by [HigherOrderSignature::lambda_parameters_max_iterations]. An implementation that
    /// keeps returning [LambdaParametersProgress::Partial] past that limit causes an error.
    ///
    /// For functions whose lambda parameters depend only on the fields of their value arguments,
    /// this can return [LambdaParametersProgress::Complete] at step 0. Take as an example a strict
    /// array_reduce with the signature `(arr: [V], initial_value: I, (I, V) -> I, (I) -> O) -> O`.
    /// It requires its initial value to have exactly the same type as the output of its merge
    /// lambda, which is also the parameter of its finish lambda. The expression
    ///
    /// `array_reduce([1.2, 2.1], 0.0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)`
    ///
    /// would result in this function being called as follows:
    ///
    /// ```ignore
    /// let lambda_parameters = array_reduce.lambda_parameters(
    ///     0,
    ///     &[
    ///         // the Field of the literal `[1.2, 2.1]`, the array being reduced
    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))),
    ///         // the Field of the literal `0.0`, the initial value
    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))),
    ///         // the Field of the output of the merge lambda, which is unknown at this point because it depends
    ///         // on the return of this call
    ///         ValueOrLambda::Lambda(None),
    ///         // the Field of the output of the finish lambda, unknown for the same reason as above
    ///         ValueOrLambda::Lambda(None),
    /// ])?;
    ///
    /// assert_eq!(
    ///      lambda_parameters,
    ///      LambdaParametersProgress::Complete(vec![
    ///         // the merge lambda supported parameters, regardless of how many are actually used
    ///         vec![
    ///             // the accumulator, which is the field of the initial value
    ///             Arc::new(Field::new("ignored_name", DataType::Float32, true)),
    ///             // the array values being reduced
    ///             Arc::new(Field::new("", DataType::Float32, true)),
    ///         ],
    ///         // the finish lambda supported parameters
    ///         vec![
    ///             // the reduced value, which is the field of the initial value
    ///             Arc::new(Field::new("ignored_name", DataType::Float32, true)),
    ///         ],
    ///      ])
    /// );
    /// ```
    ///
    /// For functions whose lambda parameters depend on the output of other lambdas, or on their
    /// own output, this can return [LambdaParametersProgress::Partial] until all dependencies are
    /// met. For lambdas with cyclic dependencies, you likely want to implement
    /// [HigherOrderUDFImpl::coerce_values_for_lambdas] too.
    ///
    /// Take as an example a flexible array_reduce with the signature
    /// `(arr: [V], initial_value: I, (ACC, V) -> ACC, (ACC) -> O) -> O`:
    /// - its merge lambda depends on its own output (a cyclic dependency);
    /// - its finish lambda depends on the output of the merge lambda;
    /// - it only requires the initial value to be *coercible* to the output of the merge lambda,
    ///   as defined by its [HigherOrderUDFImpl::coerce_values_for_lambdas] implementation.
    ///
    /// The expression
    ///
    /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)`
    ///
    /// would result in this function being called as follows:
    ///
    /// ```ignore
    /// let lambda_parameters = array_reduce.lambda_parameters(
    ///     0,
    ///     &[
    ///         // the Field of the literal `[1.2, 2.1]`, the array being reduced
    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))),
    ///         // the Field of the literal `0`, the initial value
    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))),
    ///         // the Field of the output of the merge lambda, which is unknown at this point because it depends on
    ///         // the return of this call
    ///         ValueOrLambda::Lambda(None),
    ///         // the Field of the output of the finish lambda, unknown for the same reason as above
    ///         ValueOrLambda::Lambda(None),
    /// ])?;
    ///
    /// assert_eq!(
    ///      lambda_parameters,
    ///      LambdaParametersProgress::Partial(vec![
    ///         // the merge lambda supported parameters, regardless of how many are actually used
    ///         Some(vec![
    ///             // at step 0, use the field of the initial value as a placeholder accumulator
    ///             Arc::new(Field::new("ignored_name", DataType::Int32, true)),
    ///             // the array values being reduced
    ///             Arc::new(Field::new("", DataType::Float32, true)),
    ///         ]),
    ///         // the finish lambda supported parameters, unknown at this point due to its
    ///         // dependency on the merge lambda output
    ///         None,
    ///      ])
    /// );
    ///
    /// let lambda_parameters = array_reduce.lambda_parameters(
    ///     1,
    ///     &[
    ///         // the Field of the literal `[1.2, 2.1]`, the array being reduced
    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))),
    ///         // the Field of the literal `0`, the initial value
    ///         ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))),
    ///         // the Field of the output of the merge lambda, which could be inferred to be a Float32 based on the
    ///         // parameters returned at the previous step
    ///         ValueOrLambda::Lambda(Some(Arc::new(Field::new("", DataType::Float32, true)))),
    ///         // the Field of the output of the finish lambda, which is still unknown because its
    ///         // parameters were not known at the previous step
    ///         ValueOrLambda::Lambda(None),
    /// ])?;
    ///
    /// assert_eq!(
    ///      lambda_parameters,
    ///      LambdaParametersProgress::Complete(vec![
    ///         // the merge lambda supported parameters, regardless of how many are actually used
    ///         vec![
    ///             // the merge lambda's own output, now used as its accumulator
    ///             Arc::new(Field::new("ignored_name", DataType::Float32, true)),
    ///             // the array values being reduced
    ///             Arc::new(Field::new("", DataType::Float32, true)),
    ///         ],
    ///         // the finish lambda supported parameters
    ///         vec![
    ///             // the output of the merge lambda
    ///             Arc::new(Field::new("", DataType::Float32, true)),
    ///         ],
    ///      ])
    /// );
    ///
    /// let coerce_to = array_reduce.coerce_values_for_lambdas(&[
    ///     // the literal `[1.2, 2.1]` data type, the array being reduced
    ///     ValueOrLambda::Value(DataType::new_list(DataType::Float32, true)),
    ///     // the literal `0` data type, the initial value
    ///     ValueOrLambda::Value(DataType::Int32),
    ///     // the output data type of the merge lambda
    ///     ValueOrLambda::Lambda(DataType::Float32),
    ///     // the output data type of the finish lambda
    ///     ValueOrLambda::Lambda(DataType::Boolean),
    /// ])?;
    ///
    /// assert_eq!(
    ///     coerce_to,
    ///     Some(vec![
    ///         // return the same type for the array being reduced
    ///         DataType::new_list(DataType::Float32, true),
    ///         // coerce the initial value to the output of the merge lambda
    ///         DataType::Float32,
    ///     ])
    /// );
    ///
    /// ```
    ///
    /// Note this may also be called at step 0 with all lambda outputs already set, and in that case,
    /// [LambdaParametersProgress::Complete] must be returned
    ///
    /// The implementation can assume that some other part of the code has coerced
    /// the actual argument types to match [`Self::signature`], except the coercion defined by
    /// [Self::coerce_values_for_lambdas].
    ///
    /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
    /// [`HigherOrderFunction::lambda_parameters`]: crate::expr::HigherOrderFunction::lambda_parameters
    fn lambda_parameters(
        &self,
        step: usize,
        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
    ) -> Result<LambdaParametersProgress>;
663
    /// Coerce value arguments of a function call to types that the function can evaluate, also
    /// taking into account the *output types of its lambdas*. This differs from
    /// [HigherOrderUDFImpl::coerce_value_types], which only has access to the types of the value
    /// arguments because it is called before the output types of the lambdas are known.
    ///
    /// See the [type coercion module](crate::type_coercion)
    /// documentation for more details on type coercion
    ///
    /// # Parameters
    /// * `fields`: The types of the value arguments of this function and the output types of its
    ///   lambdas, in argument order
    ///
    /// # Return value
    /// If `Some`, contains a Vec with one [DataType] per [ValueOrLambda::Value] entry in
    /// `fields`, in the same order. DataFusion will `CAST` the value arguments of the function
    /// call to these specific types. If `None`, no coercion will be applied beyond the one
    /// defined by the function signature. The default implementation returns `Ok(None)`.
    ///
    /// For example, consider a flexible array_reduce implementation (see the
    /// [Self::lambda_parameters] docs) working with the expression below. It may want to coerce
    /// its initial value argument, the *integer* `0`, to match the output of its merge lambda,
    /// which is a *float*:
    ///
    /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)`
    fn coerce_values_for_lambdas(
        &self,
        _fields: &[ValueOrLambda<DataType, DataType>],
    ) -> Result<Option<Vec<DataType>>> {
        Ok(None)
    }
691
    /// Returns the field (name, data type and nullability) describing the result of this
    /// function for the given arguments. This is a required method: there is no default
    /// implementation.
    ///
    /// The implementation can assume that some other part of the code has coerced
    /// the actual argument types to match [`Self::signature`], including the coercion
    /// defined by [Self::coerce_values_for_lambdas].
    ///
    /// # Example creating `Field`
    ///
    /// Note the name of the `Field` is ignored, except for structured types such as
    /// `DataType::Struct`.
    ///
    /// ```rust
    /// # use std::sync::Arc;
    /// # use arrow::datatypes::{DataType, Field, FieldRef};
    /// # use datafusion_common::Result;
    /// # use datafusion_expr::HigherOrderReturnFieldArgs;
    /// # struct Example{}
    /// # impl Example {
    /// fn return_field_from_args(&self, args: HigherOrderReturnFieldArgs) -> Result<FieldRef> {
    ///     let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
    ///     Ok(field)
    /// }
    /// # }
    /// ```
    fn return_field_from_args(
        &self,
        args: HigherOrderReturnFieldArgs,
    ) -> Result<FieldRef>;
720
721    /// Whether List or LargeList arguments should have it's non-empty null
722    /// sublists cleaned with [remove_list_null_values] before invoking this function
723    ///
724    /// The default implementation always returns true and should only be implemented
725    /// if you want to handle non-empty null sublists yourself
726    ///
727    /// [remove_list_null_values]: datafusion_common::utils::remove_list_null_values
728    // todo: extend this to listview and maps when remove_list_null_values supports it
729    fn clear_null_values(&self) -> bool {
730        true
731    }
732
    /// Invoke the function on the given arguments, returning the result as a
    /// [`ColumnarValue`]. This is a required method: there is no default
    /// implementation.
    ///
    /// # Performance
    ///
    /// For the best performance, the implementations should handle the common case
    /// when one or more of their arguments are constant values (aka
    /// [`ColumnarValue::Scalar`]).
    ///
    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
    /// to arrays, which will likely be simpler code, but be slower.
    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue>;
744
745    /// Returns true if some of this `exprs` subexpressions may not be evaluated
746    /// and thus any side effects (like divide by zero) may not be encountered.
747    ///
748    /// Setting this to true prevents certain optimizations such as common
749    /// subexpression elimination
750    ///
751    /// When overriding this function to return `true`, [HigherOrderUDFImpl::conditional_arguments] can also be
752    /// overridden to report more accurately which arguments are eagerly evaluated and which ones
753    /// lazily.
754    fn short_circuits(&self) -> bool {
755        false
756    }
757
758    /// Determines which of the arguments passed to *this higher-order function*
759    /// are evaluated eagerly and which may be evaluated lazily. Note that this
760    /// does *not* applies to the arguments that *lambda functions* pass to it's
761    /// body expression
762    ///
763    /// If this function returns `None`, all arguments are eagerly evaluated.
764    /// Returning `None` is a micro optimization that saves a needless `Vec`
765    /// allocation.
766    ///
767    /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
768    /// are the arguments that are always evaluated, and `lazy` are the
769    /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
770    /// in some cases).
771    ///
772    /// Implementations must ensure that the two returned `Vec`s are disjunct,
773    /// and that each argument from `args` is present in one the two `Vec`s.
774    ///
775    /// When overriding this function, [HigherOrderUDFImpl::short_circuits] must
776    /// be overridden to return `true`.
777    fn conditional_arguments<'a>(
778        &self,
779        args: &'a [Expr],
780    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
781        if self.short_circuits() {
782            Some((vec![], args.iter().collect()))
783        } else {
784            None
785        }
786    }
787
788    /// Coerce value arguments of a function call to types that the function can evaluate.
789    /// Note that if you need to coerce values based on the output type of lambdas, you
790    /// must use [HigherOrderUDFImpl::coerce_values_for_lambdas], as this function is used before
791    /// the output type of lambdas are known
792    ///
793    /// See the [type coercion module](crate::type_coercion)
794    /// documentation for more details on type coercion
795    ///
796    /// For example, if your function requires a contiguous list argument, but the user calls
797    /// it like `my_func(c, v -> v+2)` (i.e. with `c` as a ListView), coerce_types can return `[DataType::List(..)]`
798    /// to ensure the argument is converted to a List
799    ///
800    /// # Parameters
801    /// * `arg_types`: The argument types of the value arguments of this function, excluding lambdas
802    ///
803    /// # Return value
804    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
805    /// arguments to these specific types.
806    fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
807        not_impl_err!(
808            "Function {} does not implement coerce_value_types",
809            self.name()
810        )
811    }
812
    /// Returns the documentation for this function.
    ///
    /// Documentation can be accessed programmatically as well as generating
    /// publicly facing documentation.
    ///
    /// The default implementation returns `None` (no documentation).
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
820}
821
/// Logical representation of a Higher Order User Defined Function.
///
/// A higher order function takes one or more lambda arguments in addition to
/// regular value arguments. This struct contains the information DataFusion
/// needs to plan and invoke functions you supply such as name, type signature,
/// return type, and actual implementation.
///
/// Equality and hashing are delegated to the wrapped [`HigherOrderUDFImpl`]
/// (through `DynEq` / `DynHash`), while ordering compares name, signature and aliases.
#[derive(Debug, Clone)]
pub struct HigherOrderUDF {
    // The shared implementation this type delegates to; cloning a `HigherOrderUDF`
    // only clones this `Arc`.
    inner: Arc<dyn HigherOrderUDFImpl>,
}
827/// return type, and actual implementation.
828#[derive(Debug, Clone)]
829pub struct HigherOrderUDF {
830    inner: Arc<dyn HigherOrderUDFImpl>,
831}
832
833impl PartialEq for HigherOrderUDF {
834    fn eq(&self, other: &Self) -> bool {
835        self.inner.as_ref().dyn_eq(other.inner.as_ref())
836    }
837}
838
839impl PartialOrd for HigherOrderUDF {
840    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
841        let mut cmp = self.name().cmp(other.name());
842        if cmp == Ordering::Equal {
843            cmp = self.signature().partial_cmp(other.signature())?;
844        }
845        if cmp == Ordering::Equal {
846            cmp = self.aliases().partial_cmp(other.aliases())?;
847        }
848        // Contract for PartialOrd and PartialEq consistency requires that
849        // a == b if and only if partial_cmp(a, b) == Some(Equal).
850        if cmp == Ordering::Equal && self != other {
851            // Functions may have other properties besides name and signature
852            // that differentiate two instances (e.g. type, or arbitrary parameters).
853            // We cannot return Some(Equal) in such case.
854            return None;
855        }
856        debug_assert!(
857            cmp == Ordering::Equal || self != other,
858            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
859            The functions compare as equal, but they are not equal based on general properties that \
860            the PartialOrd implementation observes,",
861            self.name(),
862            other.name()
863        );
864        Some(cmp)
865    }
866}
867
868impl Eq for HigherOrderUDF {}
869
870impl Hash for HigherOrderUDF {
871    fn hash<H: Hasher>(&self, state: &mut H) {
872        self.inner.dyn_hash(state)
873    }
874}
875
876impl HigherOrderUDF {
877    /// Create a new `HigherOrderUDF` from a [`HigherOrderUDFImpl`] trait object.
878    ///
879    /// Note this is the same as using the `From` impl (`HigherOrderUDF::from`).
880    pub fn new_from_impl<F>(fun: F) -> HigherOrderUDF
881    where
882        F: HigherOrderUDFImpl + 'static,
883    {
884        Self::new_from_shared_impl(Arc::new(fun))
885    }
886
887    /// Create a new `HigherOrderUDF` from a shared [`HigherOrderUDFImpl`] trait object.
888    pub fn new_from_shared_impl(fun: Arc<dyn HigherOrderUDFImpl>) -> HigherOrderUDF {
889        Self { inner: fun }
890    }
891
892    /// Return the underlying [`HigherOrderUDFImpl`] trait object for this function.
893    pub fn inner(&self) -> &Arc<dyn HigherOrderUDFImpl> {
894        &self.inner
895    }
896
897    /// Adds additional names that can be used to invoke this function, in
898    /// addition to `name`.
899    ///
900    /// If you implement [`HigherOrderUDFImpl`] directly you should return aliases
901    /// directly.
902    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
903        Self::new_from_impl(AliasedHigherOrderUDFImpl::new(
904            Arc::clone(&self.inner),
905            aliases,
906        ))
907    }
908
909    /// Returns this function's name.
910    ///
911    /// See [`HigherOrderUDFImpl::name`] for more details.
912    pub fn name(&self) -> &str {
913        self.inner.name()
914    }
915
916    /// Returns the aliases for this function.
917    ///
918    /// See [`HigherOrderUDF::with_aliases`] for more details.
919    pub fn aliases(&self) -> &[String] {
920        self.inner.aliases()
921    }
922
923    /// Returns this function's schema_name.
924    ///
925    /// See [`HigherOrderUDFImpl::schema_name`] for more details.
926    pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
927        self.inner.schema_name(args)
928    }
929
930    /// Returns this function's [`HigherOrderSignature`].
931    pub fn signature(&self) -> &HigherOrderSignature {
932        self.inner.signature()
933    }
934
935    /// Returns the parameters of all lambdas of this function for the current step.
936    ///
937    /// See [`HigherOrderUDFImpl::lambda_parameters`] for more details.
938    pub fn lambda_parameters(
939        &self,
940        step: usize,
941        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
942    ) -> Result<LambdaParametersProgress> {
943        self.inner.lambda_parameters(step, fields)
944    }
945
946    /// Coerce value arguments based on lambda output types.
947    ///
948    /// See [`HigherOrderUDFImpl::coerce_values_for_lambdas`] for more details.
949    pub fn coerce_values_for_lambdas(
950        &self,
951        fields: &[ValueOrLambda<DataType, DataType>],
952    ) -> Result<Option<Vec<DataType>>> {
953        self.inner.coerce_values_for_lambdas(fields)
954    }
955
956    /// Returns the return field of the function given its arguments.
957    ///
958    /// See [`HigherOrderUDFImpl::return_field_from_args`] for more details.
959    pub fn return_field_from_args(
960        &self,
961        args: HigherOrderReturnFieldArgs,
962    ) -> Result<FieldRef> {
963        self.inner.return_field_from_args(args)
964    }
965
966    /// Whether List or LargeList arguments should have non-empty null sublists
967    /// cleaned before invoking this function.
968    pub fn clear_null_values(&self) -> bool {
969        self.inner.clear_null_values()
970    }
971
972    /// Invoke the function returning the appropriate result.
973    ///
974    /// See [`HigherOrderUDFImpl::invoke_with_args`] for more details.
975    pub fn invoke_with_args(
976        &self,
977        args: HigherOrderFunctionArgs,
978    ) -> Result<ColumnarValue> {
979        self.inner.invoke_with_args(args)
980    }
981
982    /// Returns true if some of this function's subexpressions may not be evaluated.
983    ///
984    /// See [`HigherOrderUDFImpl::short_circuits`] for more details.
985    pub fn short_circuits(&self) -> bool {
986        self.inner.short_circuits()
987    }
988
989    /// Returns which arguments are evaluated eagerly vs lazily.
990    ///
991    /// See [`HigherOrderUDFImpl::conditional_arguments`] for more details.
992    pub fn conditional_arguments<'a>(
993        &self,
994        args: &'a [Expr],
995    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
996        self.inner.conditional_arguments(args)
997    }
998
999    /// Coerce value arguments of a function call to types that the function can evaluate.
1000    ///
1001    /// See [`HigherOrderUDFImpl::coerce_value_types`] for more details.
1002    pub fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1003        self.inner.coerce_value_types(arg_types)
1004    }
1005
1006    /// Returns the documentation for this function, if any.
1007    pub fn documentation(&self) -> Option<&Documentation> {
1008        self.inner.documentation()
1009    }
1010}
1011
1012impl<F> From<F> for HigherOrderUDF
1013where
1014    F: HigherOrderUDFImpl + 'static,
1015{
1016    fn from(fun: F) -> Self {
1017        Self::new_from_impl(fun)
1018    }
1019}
1020
1021/// `HigherOrderUDFImpl` that adds aliases to the underlying function. It is
1022/// better to implement [`HigherOrderUDFImpl`], which supports aliases, directly
1023/// if possible.
1024#[derive(Debug, PartialEq, Eq, Hash)]
1025struct AliasedHigherOrderUDFImpl {
1026    inner: UdfEq<Arc<dyn HigherOrderUDFImpl>>,
1027    aliases: Vec<String>,
1028}
1029
1030impl AliasedHigherOrderUDFImpl {
1031    fn new(
1032        inner: Arc<dyn HigherOrderUDFImpl>,
1033        new_aliases: impl IntoIterator<Item = &'static str>,
1034    ) -> Self {
1035        let mut aliases = inner.aliases().to_vec();
1036        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1037        Self {
1038            inner: inner.into(),
1039            aliases,
1040        }
1041    }
1042}
1043
1044#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1045impl HigherOrderUDFImpl for AliasedHigherOrderUDFImpl {
1046    fn name(&self) -> &str {
1047        self.inner.name()
1048    }
1049
1050    fn aliases(&self) -> &[String] {
1051        &self.aliases
1052    }
1053
1054    fn schema_name(&self, args: &[Expr]) -> Result<String> {
1055        self.inner.schema_name(args)
1056    }
1057
1058    fn signature(&self) -> &HigherOrderSignature {
1059        self.inner.signature()
1060    }
1061
1062    fn lambda_parameters(
1063        &self,
1064        step: usize,
1065        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
1066    ) -> Result<LambdaParametersProgress> {
1067        self.inner.lambda_parameters(step, fields)
1068    }
1069
1070    fn coerce_values_for_lambdas(
1071        &self,
1072        fields: &[ValueOrLambda<DataType, DataType>],
1073    ) -> Result<Option<Vec<DataType>>> {
1074        self.inner.coerce_values_for_lambdas(fields)
1075    }
1076
1077    fn return_field_from_args(
1078        &self,
1079        args: HigherOrderReturnFieldArgs,
1080    ) -> Result<FieldRef> {
1081        self.inner.return_field_from_args(args)
1082    }
1083
1084    fn clear_null_values(&self) -> bool {
1085        self.inner.clear_null_values()
1086    }
1087
1088    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
1089        self.inner.invoke_with_args(args)
1090    }
1091
1092    fn short_circuits(&self) -> bool {
1093        self.inner.short_circuits()
1094    }
1095
1096    fn conditional_arguments<'a>(
1097        &self,
1098        args: &'a [Expr],
1099    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1100        self.inner.conditional_arguments(args)
1101    }
1102
1103    fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1104        self.inner.coerce_value_types(arg_types)
1105    }
1106
1107    fn documentation(&self) -> Option<&Documentation> {
1108        self.inner.documentation()
1109    }
1110}
1111
1112pub(crate) fn resolve_lambda_variables(
1113    expr: Expr,
1114    schema: &DFSchema,
1115    // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ]
1116    vars: &mut HashMap<String, Vec<FieldRef>>,
1117) -> Result<Transformed<Expr>> {
1118    expr.transform_down(|expr| match expr {
1119        Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
1120            // not inlined to reduce nesting
1121            resolve_higher_order_function(func, args, schema, vars)
1122        }
1123        Expr::LambdaVariable(mut var) => {
1124            let field_stack = vars.get(&var.name).ok_or_else(|| {
1125                plan_datafusion_err!(
1126                    "missing field of lambda variable {} while resolving",
1127                    var.name
1128                )
1129            })?;
1130
1131            let field = field_stack.last().ok_or_else(|| {
1132                internal_datafusion_err!("every entry should have at least one field")
1133            })?;
1134
1135            let field = Arc::clone(field).renamed(&var.name);
1136
1137            let transformed = var.field.as_ref().is_none_or(|old| old != &field);
1138
1139            var.field = Some(field);
1140
1141            Ok(Transformed::new_transformed(
1142                Expr::LambdaVariable(var),
1143                transformed,
1144            ))
1145        }
1146        _ => Ok(Transformed::no(expr)),
1147    })
1148}
1149
1150fn resolve_higher_order_function(
1151    func: Arc<HigherOrderUDF>,
1152    args: Vec<Expr>,
1153    schema: &DFSchema,
1154    // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ]
1155    vars: &mut HashMap<String, Vec<FieldRef>>,
1156) -> Result<Transformed<Expr>> {
1157    let args = if !vars.is_empty() {
1158        /*  if this is a nested lambda, we must resolve non-lambda args before invoking
1159            lambda_parameters because it will invoke ExprSchemable::to_field for every
1160            non-lambda parameter, and if one them contains a lambda variable, it will fail
1161            due to it being unresolved. Example query:
1162
1163            array_transform([[1, 2]], a -> array_transform(a, b -> b+1))
1164
1165            the nested array_transform's lambda_parameters will call Lambdavariable::to_field
1166            on it's first argument, the variable `a`, which must be resolved
1167        */
1168        args.map_elements(|arg| match arg {
1169            Expr::Lambda(_) => Ok(Transformed::no(arg)),
1170            _ => resolve_lambda_variables(arg, schema, vars),
1171        })?
1172    } else {
1173        Transformed::no(args)
1174    };
1175
1176    let transformed = args.transformed;
1177    let mut args = args.data;
1178
1179    let current_fields = args
1180        .iter()
1181        .map(|e| match e {
1182            Expr::Lambda(_lambda_function) => Ok(ValueOrLambda::Lambda(None)),
1183            _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)),
1184        })
1185        .collect::<Result<Vec<_>>>()?;
1186
1187    // coerce fields because coercion may alter the lambda parameters
1188    let mut fields = value_fields_with_higher_order_udf(&current_fields, func.as_ref())?;
1189
1190    let num_lambdas = args.iter().filter(|a| matches!(a, Expr::Lambda(_))).count();
1191
1192    let mut step = 0;
1193
1194    let lambda_params = loop {
1195        match func.lambda_parameters(step, &fields)? {
1196            LambdaParametersProgress::Partial(params) => {
1197                let mut params = params.into_iter();
1198
1199                if params.len() != num_lambdas {
1200                    return plan_err!(
1201                        "{} lambda_parameters returned {} lambdas but {num_lambdas} expected",
1202                        func.name(),
1203                        params.len()
1204                    );
1205                }
1206
1207                for (arg, field) in std::iter::zip(&mut args, &mut fields) {
1208                    match (arg, field) {
1209                        (Expr::Lambda(lambda), ValueOrLambda::Lambda(field)) => {
1210                            let params = params.next().ok_or_else(|| {
1211                                internal_datafusion_err!(
1212                                    "params len should have been checked above"
1213                                )
1214                            })?;
1215
1216                            if let Some(params) = params {
1217                                for (name, field) in
1218                                    std::iter::zip(&lambda.params, params)
1219                                {
1220                                    vars.entry_ref(name)
1221                                        .or_default()
1222                                        .push(field.renamed(name.as_str()));
1223                                }
1224
1225                                let body_with_vars = resolve_lambda_variables(
1226                                    mem::take(lambda.body.as_mut()),
1227                                    schema,
1228                                    vars,
1229                                )?;
1230
1231                                remove_scope(vars, &lambda.params)?;
1232
1233                                *field = Some(body_with_vars.data.to_field(schema)?.1);
1234                                *lambda.body = body_with_vars.data;
1235                            }
1236                        }
1237                        (_, ValueOrLambda::Lambda(_)) => {
1238                            return internal_err!(
1239                                "value_fields_with_higher_order_udf returned a value for a lambda argument"
1240                            );
1241                        }
1242                        (Expr::Lambda(_), ValueOrLambda::Value(_)) => {
1243                            return internal_err!(
1244                                "value_fields_with_higher_order_udf returned a lambda for a value argument"
1245                            );
1246                        }
1247                        (_, ValueOrLambda::Value(_)) => {} // nothing to do
1248                    }
1249                }
1250            }
1251            LambdaParametersProgress::Complete(params) => break params,
1252        }
1253
1254        let limit = func.signature().lambda_parameters_max_iterations;
1255
1256        step += 1;
1257
1258        if step > limit {
1259            return plan_err!(
1260                "{} lambda_parameters called {limit} times without completion",
1261                func.name()
1262            );
1263        }
1264    };
1265
1266    let mut lambda_params = lambda_params.into_iter();
1267
1268    if num_lambdas != lambda_params.len() {
1269        return plan_err!(
1270            "{} lambda_parameters returned {} values for {num_lambdas} lambdas",
1271            func.name(),
1272            lambda_params.len()
1273        );
1274    }
1275
1276    let args = args.map_elements(|arg| match arg {
1277        Expr::Lambda(mut lambda) => {
1278            let lambda_params = lambda_params.next().ok_or_else(|| {
1279                internal_datafusion_err!(
1280                    "lambda_params len should have been checked above"
1281                )
1282            })?;
1283
1284            if lambda.params.len() > lambda_params.len() {
1285                return plan_err!(
1286                    "{} lambda defined {} params ({}), but only {} supported",
1287                    func.name(),
1288                    lambda.params.len(),
1289                    display_comma_separated(&lambda.params),
1290                    lambda_params.len()
1291                );
1292            }
1293
1294            if !all_unique(&lambda.params) {
1295                return plan_err!(
1296                    "lambda params must be unique, got ({})",
1297                    lambda.params.join(", ")
1298                );
1299            }
1300
1301            for (param, field) in std::iter::zip(&lambda.params, lambda_params) {
1302                vars.entry_ref(param)
1303                    .or_default()
1304                    .push(field.renamed(param.as_str()));
1305            }
1306
1307            let transformed =
1308                resolve_lambda_variables(mem::take(lambda.body.as_mut()), schema, vars)?;
1309
1310            *lambda.body = transformed.data;
1311
1312            remove_scope(vars, &lambda.params)?;
1313
1314            Ok(Transformed::new(
1315                Expr::Lambda(lambda),
1316                transformed.transformed,
1317                TreeNodeRecursion::Jump,
1318            ))
1319        }
1320        arg => Ok(Transformed::no(arg)), // resolved at the start of the function
1321    })?;
1322
1323    Ok(Transformed::new(
1324        Expr::HigherOrderFunction(HigherOrderFunction::new(func, args.data)),
1325        transformed || args.transformed,
1326        TreeNodeRecursion::Jump,
1327    ))
1328}
1329
1330fn remove_scope(
1331    vars: &mut HashMap<String, Vec<FieldRef>>,
1332    scope: &[String],
1333) -> Result<()> {
1334    for param in scope {
1335        match vars.entry_ref(param) {
1336            EntryRef::Occupied(mut v) => {
1337                if v.get().len() == 1 {
1338                    v.remove();
1339                } else {
1340                    v.get_mut().pop().ok_or_else(|| {
1341                        internal_datafusion_err!(
1342                            "every entry should have at least one field"
1343                        )
1344                    })?;
1345                }
1346            }
1347            EntryRef::Vacant(_v) => {
1348                return internal_err!("no empty value should be in the map");
1349            }
1350        }
1351    }
1352
1353    Ok(())
1354}
1355
1356fn all_unique(params: &[String]) -> bool {
1357    match params.len() {
1358        0 | 1 => true,
1359        2 => params[0] != params[1],
1360        _ => {
1361            let mut set = HashSet::with_capacity(params.len());
1362
1363            params.iter().all(|p| set.insert(p.as_str()))
1364        }
1365    }
1366}
1367
1368#[cfg(test)]
1369mod tests {
1370    use super::*;
1371    use std::hash::DefaultHasher;
1372    use std::sync::Arc;
1373
1374    use arrow_schema::{DataType, Field, FieldRef, Schema};
1375    use datafusion_common::{DFSchema, Result};
1376    use datafusion_expr_common::columnar_value::ColumnarValue;
1377    use datafusion_expr_common::signature::Volatility;
1378
1379    use crate::{
1380        Expr, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl,
1381        LambdaParametersProgress, ValueOrLambda, col,
1382        expr::{HigherOrderFunction, LambdaVariable},
1383        lambda, lambda_var, lit,
1384    };
1385
1386    #[derive(Debug, PartialEq, Eq, Hash)]
1387    struct TestHigherOrderUDF {
1388        name: &'static str,
1389        field: &'static str,
1390        signature: HigherOrderSignature,
1391    }
1392    impl HigherOrderUDFImpl for TestHigherOrderUDF {
1393        fn name(&self) -> &str {
1394            self.name
1395        }
1396
1397        fn signature(&self) -> &HigherOrderSignature {
1398            &self.signature
1399        }
1400
1401        fn lambda_parameters(
1402            &self,
1403            _step: usize,
1404            _fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
1405        ) -> Result<LambdaParametersProgress> {
1406            unimplemented!()
1407        }
1408
1409        fn return_field_from_args(
1410            &self,
1411            _args: HigherOrderReturnFieldArgs,
1412        ) -> Result<FieldRef> {
1413            unimplemented!()
1414        }
1415
1416        fn invoke_with_args(
1417            &self,
1418            _args: HigherOrderFunctionArgs,
1419        ) -> Result<ColumnarValue> {
1420            unimplemented!()
1421        }
1422    }
1423
1424    // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
1425    // must be consistent, so they are tested together.
1426    #[test]
1427    fn test_partial_eq_hash_and_partial_ord() {
1428        // A parameterized function
1429        let f = test_func("foo", "a");
1430
1431        // Same like `f`, different instance
1432        let f2 = test_func("foo", "a");
1433        assert_eq!(&f, &f2);
1434        assert_eq!(hash(&f), hash(&f2));
1435        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
1436
1437        // Different parameter
1438        let b = test_func("foo", "b");
1439        assert_ne!(&f, &b);
1440        assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
1441        assert_eq!(f.partial_cmp(&b), None);
1442
1443        // Different name
1444        let o = test_func("other", "a");
1445        assert_ne!(&f, &o);
1446        assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
1447        assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));
1448
1449        // Different name and parameter
1450        assert_ne!(&b, &o);
1451        assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
1452        assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
1453    }
1454
1455    fn test_func(name: &'static str, parameter: &'static str) -> Arc<HigherOrderUDF> {
1456        Arc::new(HigherOrderUDF::new_from_impl(TestHigherOrderUDF {
1457            name,
1458            field: parameter,
1459            signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
1460        }))
1461    }
1462
1463    fn hash<T: Hash>(value: &T) -> u64 {
1464        let hasher = &mut DefaultHasher::new();
1465        value.hash(hasher);
1466        hasher.finish()
1467    }
1468
1469    #[derive(Debug, PartialEq, Eq, Hash)]
1470    struct MockArrayReduce {
1471        signature: HigherOrderSignature,
1472    }
1473
1474    impl HigherOrderUDFImpl for MockArrayReduce {
1475        fn name(&self) -> &str {
1476            "array_reduce"
1477        }
1478
1479        fn aliases(&self) -> &[String] {
1480            &[]
1481        }
1482
1483        fn signature(&self) -> &HigherOrderSignature {
1484            &self.signature
1485        }
1486
1487        fn lambda_parameters(
1488            &self,
1489            step: usize,
1490            fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
1491        ) -> Result<LambdaParametersProgress> {
1492            // optional finish not supported for simplicity
1493            let [
1494                ValueOrLambda::Value(list),
1495                ValueOrLambda::Value(initial_value),
1496                ValueOrLambda::Lambda(merge),
1497                ValueOrLambda::Lambda(_finish),
1498            ] = fields
1499            else {
1500                unreachable!()
1501            };
1502
1503            let list_field = match list.data_type() {
1504                DataType::List(field) => field,
1505                _ => unreachable!(),
1506            };
1507
1508            Ok(match (step, merge) {
1509                (0, None) => {
1510                    // at the first step, we use the initial_value as merge accumulator,
1511                    // and return None for finish since we don't know the output of merge
1512                    LambdaParametersProgress::Partial(vec![
1513                        // merge
1514                        Some(vec![Arc::clone(initial_value), Arc::clone(list_field)]),
1515                        // finish
1516                        None,
1517                    ])
1518                }
1519                (1, Some(accumulator)) | (0, Some(accumulator)) => {
1520                    // now we can use the merge output as it's accumulator and
1521                    // as the finish parameter
1522                    LambdaParametersProgress::Complete(vec![
1523                        // merge
1524                        vec![Arc::clone(accumulator), Arc::clone(list_field)],
1525                        // finish
1526                        vec![Arc::clone(accumulator)],
1527                    ])
1528                }
1529                (1, None) => {
1530                    unreachable!()
1531                }
1532                _ => unreachable!(),
1533            })
1534        }
1535
1536        fn return_field_from_args(
1537            &self,
1538            args: HigherOrderReturnFieldArgs,
1539        ) -> Result<FieldRef> {
1540            // optional finish not supported for simplicity
1541            let [
1542                ValueOrLambda::Value(_list),
1543                ValueOrLambda::Value(_initial_value),
1544                ValueOrLambda::Lambda(_merge),
1545                ValueOrLambda::Lambda(finish),
1546            ] = args.arg_fields
1547            else {
1548                unreachable!()
1549            };
1550
1551            Ok(Arc::clone(finish))
1552        }
1553
1554        fn invoke_with_args(
1555            &self,
1556            _args: HigherOrderFunctionArgs,
1557        ) -> Result<ColumnarValue> {
1558            unreachable!()
1559        }
1560    }
1561
/// Checks that `resolve_lambda_variables` attaches the correct fields to every
/// lambda variable of a nested `array_reduce` call, including a variable
/// captured from the outer lambda (`acc1`) and a shadowed parameter name (`v`).
#[test]
fn test_resolve_lambda_variables() {
    let schema = DFSchema::try_from(Schema::new(vec![Field::new(
        "c",
        DataType::new_list(DataType::new_list(DataType::Int32, true), true),
        true,
    )]))
    .unwrap();

    let func = Arc::new(HigherOrderUDF::new_from_impl(MockArrayReduce {
        signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
    }));

    /*
       array_reduce(
           c,
           0,
           (acc1, v) -> acc1 + array_reduce(
               v,
               0,
               (acc2, v) -> acc2 + acc1 + v,
               reduced -> reduced * 2.0
           ),
           reduced -> reduced * 2
       )
    */
    let expr = Expr::HigherOrderFunction(HigherOrderFunction::new(
        Arc::clone(&func),
        vec![
            col("c"),
            lit(0),
            lambda(
                ["acc1", "v"],
                lambda_var("acc1")
                    + Expr::HigherOrderFunction(HigherOrderFunction::new(
                        Arc::clone(&func),
                        vec![
                            lambda_var("v"),
                            lit(0),
                            lambda(
                                ["acc2", "v"],
                                lambda_var("acc2")
                                    + lambda_var("acc1")
                                    + lambda_var("v"),
                            ),
                            lambda(["reduced"], lambda_var("reduced") * lit(2.0)),
                        ],
                    )),
            ),
            lambda(["reduced"], lambda_var("reduced") * lit(2)),
        ],
    ));

    let resolved_expr = expr.resolve_lambda_variables(&schema).unwrap().data;

    /*
       array_reduce(
           c@[[Int32]],
           0@Int32,
           (acc1@Float64, v@[Int32]) -> acc1@Float64 + array_reduce(
               v@[Int32],
               0@Int32,
               (acc2@Float64, v@Int32) -> acc2@Float64 + acc1@Float64 + v@Int32,
               reduced@Float64 -> reduced@Float64 * 2.0@Float64
           ),
           reduced@Float64 -> reduced@Float64 * 2@Int32
       )
    */
    let expected = Expr::HigherOrderFunction(HigherOrderFunction::new(
        Arc::clone(&func),
        vec![
            col("c"),
            lit(0),
            lambda(
                ["acc1", "v"],
                resolved_lambda_var("acc1", DataType::Float64, true)
                    + Expr::HigherOrderFunction(HigherOrderFunction::new(
                        Arc::clone(&func),
                        vec![
                            resolved_lambda_var(
                                "v",
                                DataType::new_list(DataType::Int32, true),
                                true,
                            ),
                            lit(0),
                            lambda(
                                ["acc2", "v"],
                                resolved_lambda_var("acc2", DataType::Float64, true)
                                    + resolved_lambda_var(
                                        "acc1",
                                        DataType::Float64,
                                        true,
                                    )
                                    + resolved_lambda_var("v", DataType::Int32, true),
                            ),
                            lambda(
                                ["reduced"],
                                resolved_lambda_var(
                                    "reduced",
                                    DataType::Float64,
                                    true,
                                ) * lit(2.0),
                            ),
                        ],
                    )),
            ),
            lambda(
                ["reduced"],
                resolved_lambda_var("reduced", DataType::Float64, true) * lit(2),
            ),
        ],
    ));

    assert_eq!(resolved_expr, expected);
}
1677
1678    fn resolved_lambda_var(name: &str, dt: DataType, nullable: bool) -> Expr {
1679        Expr::LambdaVariable(LambdaVariable::new(
1680            name.into(),
1681            Some(Arc::new(Field::new(name, dt, nullable))),
1682        ))
1683    }
1684}