datafusion_expr/higher_order_function.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`HigherOrderUDF`]: User Defined Higher Order Functions
19
20use crate::expr::{
21 HigherOrderFunction, display_comma_separated,
22 schema_name_from_exprs_comma_separated_without_space,
23};
24use crate::type_coercion::functions::value_fields_with_higher_order_udf;
25use crate::udf_eq::UdfEq;
26use crate::{ColumnarValue, Documentation, Expr, ExprSchemable};
27use arrow::array::{ArrayRef, RecordBatch};
28use arrow::datatypes::{DataType, FieldRef, Schema};
29use arrow_schema::SchemaRef;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::datatype::FieldExt;
32use datafusion_common::hash_map::EntryRef;
33use datafusion_common::tree_node::{
34 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
35};
36use datafusion_common::{
37 DFSchema, HashMap, HashSet, Result, ScalarValue, exec_err, internal_datafusion_err,
38 internal_err, not_impl_err, plan_datafusion_err, plan_err,
39};
40use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
41use datafusion_expr_common::signature::Volatility;
42use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
43use std::any::Any;
44use std::cmp::Ordering;
45use std::fmt::Debug;
46use std::hash::{Hash, Hasher};
47use std::mem;
48use std::sync::Arc;
49
50/// The types of arguments for which a function has implementations.
51///
52/// [`HigherOrderTypeSignature`] **DOES NOT** define the types that a user query could call the
53/// function with. DataFusion will automatically coerce (cast) argument types to
54/// one of the supported function signatures, if possible.
55///
56/// # Overview
57/// Functions typically provide implementations for a small number of different
58/// argument [`DataType`]s, rather than all possible combinations. If a user
59/// calls a function with arguments that do not match any of the declared types,
60/// DataFusion will attempt to automatically coerce (add casts to) function
61/// arguments so they match the [`HigherOrderTypeSignature`]. See the [`type_coercion`] module
62/// for more details
63///
64/// [`type_coercion`]: crate::type_coercion
65#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
66pub enum HigherOrderTypeSignature {
67 /// The acceptable signature and coercions rules are special for this
68 /// function.
69 ///
70 /// If this signature is specified,
71 /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare argument types.
72 UserDefined,
73 /// One or more lambdas or arguments with arbitrary types
74 VariadicAny,
75 /// The specified number of lambdas or arguments with arbitrary types.
76 Any(usize),
77 /// Exactly the specified arguments in the given order, with arbitrary types.
78 /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value
79 /// argument types.
80 Exact(Vec<ValueOrLambda<(), ()>>),
81}
82
83/// Provides information necessary for calling a higher order function.
84///
85/// - [`HigherOrderTypeSignature`] defines the argument types that a function has implementations
86/// for.
87///
88/// - [`Volatility`] defines how the output of the function changes with the input.
89#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
90pub struct HigherOrderSignature {
91 /// The data types that the function accepts. See [HigherOrderTypeSignature] for more information.
92 pub type_signature: HigherOrderTypeSignature,
93 /// The volatility of the function. See [Volatility] for more information.
94 pub volatility: Volatility,
95 /// The max number of times to call [HigherOrderUDFImpl::lambda_parameters] before raising an error.
96 /// Used to guard against implementations that causes an infinite loop by endlessly returning
97 /// [LambdaParametersProgress::Partial]. Defaults to 256
98 pub lambda_parameters_max_iterations: usize,
99}
100
101const LAMBDA_PARAMETERS_MAX_ITERATIONS: usize = 256;
102
103impl HigherOrderSignature {
104 /// Creates a new `HigherOrderSignature` from a given type signature and volatility.
105 pub fn new(type_signature: HigherOrderTypeSignature, volatility: Volatility) -> Self {
106 HigherOrderSignature {
107 type_signature,
108 volatility,
109 lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
110 }
111 }
112
113 /// User-defined coercion rules for the function.
114 pub fn user_defined(volatility: Volatility) -> Self {
115 Self {
116 type_signature: HigherOrderTypeSignature::UserDefined,
117 volatility,
118 lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
119 }
120 }
121
122 /// An arbitrary number of lambdas or arguments of any type.
123 pub fn variadic_any(volatility: Volatility) -> Self {
124 Self {
125 type_signature: HigherOrderTypeSignature::VariadicAny,
126 volatility,
127 lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
128 }
129 }
130
131 /// A specified number of arguments of any type
132 pub fn any(arg_count: usize, volatility: Volatility) -> Self {
133 Self {
134 type_signature: HigherOrderTypeSignature::Any(arg_count),
135 volatility,
136 lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
137 }
138 }
139
140 /// Exactly the specified arguments in the given order, with arbitrary types.
141 /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value
142 /// argument types.
143 ///
144 /// # Example
145 /// A function that takes one value argument followed by one lambda:
146 /// ```
147 /// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility};
148 /// let sig = HigherOrderSignature::exact(
149 /// vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
150 /// Volatility::Immutable,
151 /// );
152 /// ```
153 pub fn exact(args: Vec<ValueOrLambda<(), ()>>, volatility: Volatility) -> Self {
154 Self {
155 type_signature: HigherOrderTypeSignature::Exact(args),
156 volatility,
157 lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
158 }
159 }
160}
161
162impl PartialEq for dyn HigherOrderUDFImpl {
163 fn eq(&self, other: &Self) -> bool {
164 self.dyn_eq(other as _)
165 }
166}
167
168impl PartialOrd for dyn HigherOrderUDFImpl {
169 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
170 let mut cmp = self.name().cmp(other.name());
171 if cmp == Ordering::Equal {
172 cmp = self.signature().partial_cmp(other.signature())?;
173 }
174 if cmp == Ordering::Equal {
175 cmp = self.aliases().partial_cmp(other.aliases())?;
176 }
177 // Contract for PartialOrd and PartialEq consistency requires that
178 // a == b if and only if partial_cmp(a, b) == Some(Equal).
179 if cmp == Ordering::Equal && self != other {
180 // Functions may have other properties besides name and signature
181 // that differentiate two instances (e.g. type, or arbitrary parameters).
182 // We cannot return Some(Equal) in such case.
183 return None;
184 }
185 debug_assert!(
186 cmp == Ordering::Equal || self != other,
187 "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
188 The functions compare as equal, but they are not equal based on general properties that \
189 the PartialOrd implementation observes,",
190 self.name(),
191 other.name()
192 );
193 Some(cmp)
194 }
195}
196
197impl Eq for dyn HigherOrderUDFImpl {}
198
199impl Hash for dyn HigherOrderUDFImpl {
200 fn hash<H: Hasher>(&self, state: &mut H) {
201 self.dyn_hash(state)
202 }
203}
204
205/// Arguments passed to [`HigherOrderUDFImpl::invoke_with_args`] when invoking a
206/// higher order function.
207#[derive(Debug, Clone)]
208pub struct HigherOrderFunctionArgs {
209 /// The evaluated arguments and lambdas to the function
210 pub args: Vec<ValueOrLambda<ColumnarValue, LambdaArgument>>,
211 /// Field associated with each arg, if it exists
212 /// For lambdas, it will be the field of the result of
213 /// the lambda if evaluated with the parameters
214 /// returned from [`HigherOrderUDFImpl::lambda_parameters`]
215 pub arg_fields: Vec<ValueOrLambda<FieldRef, FieldRef>>,
216 /// The number of rows in record batch being evaluated
217 pub number_rows: usize,
218 /// The return field of the higher order function returned
219 /// (from `return_field_from_args`) when creating the
220 /// physical expression from the logical expression
221 pub return_field: FieldRef,
222 /// The config options at execution time
223 pub config_options: Arc<ConfigOptions>,
224}
225
226impl HigherOrderFunctionArgs {
227 /// The return type of the function. See [`Self::return_field`] for more
228 /// details.
229 pub fn return_type(&self) -> &DataType {
230 self.return_field.data_type()
231 }
232}
233
234/// A lambda argument to a HigherOrderFunction
235#[derive(Clone, Debug)]
236pub struct LambdaArgument {
237 /// The parameters defined in this lambda
238 ///
239 /// For example, for `array_transform([2], v -> -v)`,
240 /// this will be `vec![Field::new("v", DataType::Int32, true)]`
241 params: Vec<FieldRef>,
242 /// The body of the lambda
243 ///
244 /// For example, for `array_transform([2], v -> -v)`,
245 /// this will be the physical expression of `-v`
246 body: Arc<dyn PhysicalExpr>,
247 /// Cached schema built from `params`. Reused across every `evaluate` call
248 /// (and across every nested-list iteration when the lambda is called once
249 /// per outer sublist), avoiding the per-call `Schema::new` build that
250 /// includes constructing the internal name -> index map.
251 schema: SchemaRef,
252 /// A RecordBatch containing the captured columns inside this lambda body, if any
253 ///
254 /// For example, for `array_transform([2], v -> v + a + b)`,
255 /// this will be a `RecordBatch` with two columns, `a` and `b`
256 captures: Option<RecordBatch>,
257}
258
259impl LambdaArgument {
260 pub fn new(
261 params: Vec<FieldRef>,
262 body: Arc<dyn PhysicalExpr>,
263 captures: Option<RecordBatch>,
264 ) -> Self {
265 let fields = match &captures {
266 Some(batch) => batch
267 .schema_ref()
268 .fields()
269 .iter()
270 .cloned()
271 .chain(params.clone())
272 .collect(),
273 None => params.clone(),
274 };
275
276 let schema = Arc::new(Schema::new(fields));
277
278 Self {
279 params,
280 body,
281 schema,
282 captures,
283 }
284 }
285
286 /// Evaluate this lambda
287 /// `args` should evaluate to the value of each parameter
288 /// of the correspondent lambda returned in [HigherOrderUDFImpl::lambda_parameters].
289 ///
290 /// `spread_captures` is responsible for transforming the captured column arrays
291 /// so they align with the evaluation batch. Captures are snapshotted from the
292 /// outer batch at construction time, giving one value per outer row, but the
293 /// function may evaluate the lambda body over a batch with a different number
294 /// of rows. It is the function's responsibility to provide the appropriate
295 /// `spread_captures` closure to expand (or otherwise reshape) the captures
296 /// to match.
297 ///
298 /// Taking as an example the following table:
299 ///
300 /// ```sql
301 /// CREATE TABLE t (arr INT[], a INT) AS VALUES
302 /// ([1, 2, 3], 10),
303 /// ([], 20),
304 /// ([4], 30);
305 /// ```
306 ///
307 /// `SELECT array_transform(arr, v -> v + a) from t` would execute over three outer rows:
308 ///
309 /// ```text
310 /// arr (ListArray): [[1, 2, 3], [], [4]] -- 3 outer rows, 4 total elements
311 /// a (captured): [10, 20, 30] -- one value per outer row
312 /// ```
313 ///
314 /// `array_transform` flattens the list elements into a single batch of 4 rows,
315 /// so `spread_captures` must repeat/drop captured values to match:
316 ///
317 /// ```text
318 /// v (flattened args): [1, 2, 3, 4]
319 /// a (spread): [10, 10, 10, 30] -- 10 repeated for 3 elements in row 0,
320 /// -- 20 dropped for the empty sublist in row 1,
321 /// -- 30 once for the single element in row 2
322 /// ```
323 ///
324 /// The lambda body `v + a` then evaluates element-wise over these 4-row arrays,
325 /// producing `[11, 12, 13, 34]`, which `array_transform` reassembles into `[[11, 12, 13], [], [34]]`.
326 ///
327 /// If the lambda has no captures, `spread_captures` is never called.
328 pub fn evaluate(
329 &self,
330 args: &[&dyn Fn() -> Result<ArrayRef>],
331 spread_captures: impl FnOnce(&[ArrayRef]) -> Result<Vec<ArrayRef>>,
332 ) -> Result<ColumnarValue> {
333 let spread_captures = self
334 .captures
335 .as_ref()
336 .map(|captures| {
337 let spread_columns = spread_captures(captures.columns())?;
338
339 RecordBatch::try_new(captures.schema(), spread_columns)
340 })
341 .transpose()?;
342
343 let merged = merge_captures_with_variables(
344 spread_captures.as_ref(),
345 Arc::clone(&self.schema),
346 &self.params,
347 args,
348 )?;
349
350 self.body.evaluate(&merged)
351 }
352}
353
354fn merge_captures_with_variables(
355 captures: Option<&RecordBatch>,
356 schema: SchemaRef,
357 params: &[FieldRef],
358 variables: &[&dyn Fn() -> Result<ArrayRef>],
359) -> Result<RecordBatch> {
360 if variables.len() < params.len() {
361 return exec_err!(
362 "expected at least {} lambda arguments to merge with captures, got {}",
363 params.len(),
364 variables.len()
365 );
366 }
367
368 let columns = match captures {
369 Some(captures) => {
370 let mut columns = captures.columns().to_vec();
371
372 for arg in &variables[..params.len()] {
373 columns.push(arg()?);
374 }
375
376 columns
377 }
378 None => variables
379 .iter()
380 .take(params.len())
381 .map(|arg| arg())
382 .collect::<Result<_>>()?,
383 };
384
385 Ok(RecordBatch::try_new(schema, columns)?)
386}
387
388/// Information about arguments passed to the function
389///
390/// This structure contains metadata about how the function was called
391/// such as the type of the arguments, any scalar arguments and if the
392/// arguments can (ever) be null
393///
394/// See [`HigherOrderUDFImpl::return_field_from_args`] for more information
395#[derive(Clone, Debug)]
396pub struct HigherOrderReturnFieldArgs<'a> {
397 /// The data types of the arguments to the function
398 ///
399 /// If argument `i` to the function is a lambda, it will be the field of the result of the
400 /// lambda if evaluated with the parameters returned from [`HigherOrderUDFImpl::lambda_parameters`]
401 ///
402 /// For example, with `array_transform([1], v -> v == 5)`
403 /// this field will be
404 /// ```ignore
405 /// [
406 /// ValueOrLambda::Value(Field::new("", DataType::new_list(DataType::Int32, true), true)),
407 /// ValueOrLambda::Lambda(Field::new("", DataType::Boolean, true))
408 /// ]
409 /// ```
410 pub arg_fields: &'a [ValueOrLambda<FieldRef, FieldRef>],
411 /// Is argument `i` to the function a scalar (constant)?
412 ///
413 /// If the argument `i` is not a scalar, it will be None
414 ///
415 /// For example, if a function is called like `array_transform([1], v -> v == 5)`
416 /// this field will be `[Some(ScalarValue::List(...), None]`
417 pub scalar_arguments: &'a [Option<&'a ScalarValue>],
418}
419
/// An argument to a higher order function: either a plain value or a lambda.
///
/// The same shape attaches different payloads to each argument position, e.g.
/// `ValueOrLambda<FieldRef, FieldRef>` for argument fields, or `ValueOrLambda<(), ()>`
/// to only mark the kind of each argument in a signature.
///
/// The derived ordering follows declaration order, so every `Value` sorts before
/// every `Lambda`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub enum ValueOrLambda<V, L> {
    /// A value argument, carrying its associated data `V`
    Value(V),
    /// A lambda argument, carrying its associated data `L`
    Lambda(L),
}
428
429/// Represents a step during the resolution of the parameters of all lambdas of a given
430/// higher-order function via [HigherOrderUDFImpl::lambda_parameters]. It's valid that the
431/// fields of a given lambda changes between steps, and is up to the implementation to
432/// provide during the function evaluation the parameters that matches the fields returned
433/// at the [LambdaParametersProgress::Complete] step. See [HigherOrderUDFImpl::lambda_parameters]
434/// docs for more details
435pub enum LambdaParametersProgress {
436 /// The parameters of some lambdas are unknown due to a dependency on another lambda output field
437 /// or are placeholders due to a dependency on it's own output field. It's perfectly valid to
438 /// contain only `Some`'s and not a single `None`, representing lambdas that depends only on itself
439 /// and not on others. [HigherOrderUDFImpl::lambda_parameters] will be called again with the output
440 /// field of all lambdas with known parameters.
441 Partial(Vec<Option<Vec<FieldRef>>>),
442 /// There are no unmet dependencies and all parameters are known, [HigherOrderUDFImpl::lambda_parameters]
443 /// will not be called again
444 Complete(Vec<Vec<FieldRef>>),
445}
446
447/// Trait for implementing user defined higher order functions.
448///
449/// This trait exposes the full API for implementing user defined functions and
450/// can be used to implement any function.
451///
452/// New higher order functions typically implement this trait and are then
453/// wrapped in a [`HigherOrderUDF`] for registration with DataFusion.
454///
455/// See [`array_transform.rs`] for a commented complete implementation
456///
457/// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs
458pub trait HigherOrderUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
459 /// Returns this function's name
460 fn name(&self) -> &str;
461
462 /// Returns any aliases (alternate names) for this function.
463 ///
464 /// Aliases can be used to invoke the same function using different names.
465 /// For example in some databases `now()` and `current_timestamp()` are
466 /// aliases for the same function. This behavior can be obtained by
467 /// returning `current_timestamp` as an alias for the `now` function.
468 ///
469 /// Note: `aliases` should only include names other than [`Self::name`].
470 /// Defaults to `[]` (no aliases)
471 fn aliases(&self) -> &[String] {
472 &[]
473 }
474
475 /// Returns the name of the column this expression would create
476 ///
477 /// See [`Expr::schema_name`] for details
478 fn schema_name(&self, args: &[Expr]) -> Result<String> {
479 Ok(format!(
480 "{}({})",
481 self.name(),
482 schema_name_from_exprs_comma_separated_without_space(args)?
483 ))
484 }
485
486 /// Returns a [`HigherOrderSignature`] describing the argument types for which this
487 /// function has an implementation, and the function's [`Volatility`].
488 ///
489 /// See [`HigherOrderSignature`] for more details on argument type handling
490 /// and [`Self::return_field_from_args`] for computing the return type.
491 ///
492 /// [`Volatility`]: datafusion_expr_common::signature::Volatility
493 fn signature(&self) -> &HigherOrderSignature;
494
495 /// Return the field of all the parameters supported by the lambdas in `fields`.
496 /// If a lambda support multiple parameters, all should be returned, regardless of
497 /// whether they are used or not on a particular invocation
498 ///
499 /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call the helper
500 /// [`HigherOrderFunction::lambda_parameters`] instead of this method directly
501 ///
502 /// The name of the returned fields are ignored.
503 ///
504 /// This function is repeatedelly called until [LambdaParametersProgress::Complete] is returned, with
505 /// `step` increased by one at each invocation, starting at 0.
506 ///
507 /// For functions which all lambda parameters depend only on the field of it's value arguments,
508 /// this can return [LambdaParametersProgress::Complete] at step 0. Taking as an example a strict
509 /// array_reduce with the signature `(arr: [V], initial_value: I, (I, V) -> I, (I) -> O) -> O`, which
510 /// requires it's initial value to be the exact same type of it's merge output, which is also the
511 /// parameter of it's finish lambda, the expression
512 ///
513 /// `array_reduce([1.2, 2.1], 0.0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)`
514 ///
515 /// would result in this function being called as the following:
516 ///
517 /// ```ignore
518 /// let lambda_parameters = array_reduce.lambda_parameters(
519 /// 0,
520 /// &[
521 /// // the Field of the literal `[1.2, 2.1]`, the array being reduced
522 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))),
523 /// // the Field of the literal `0.0`, the initial value
524 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))),
525 /// // the Field of the output of the merge lambda, which is unknown at this point because it depends
526 /// // on the return of this call
527 /// ValueOrLambda::Lambda(None),
528 /// // the Field of the output of the finish lambda, unknown for the same reason as above
529 /// ValueOrLambda::Lambda(None),
530 /// ])?;
531 ///
532 /// assert_eq!(
533 /// lambda_parameters,
534 /// LambdaParametersProgress::Complete(vec![
535 /// // the finish lambda supported parameters, regardless of how many are actually used
536 /// vec![
537 /// // the accumulator which is the field of the initial value
538 /// Arc::new(Field::new("ignored_name", DataType::Float32, true)),
539 /// // the array values being reduced
540 /// Arc::new(Field::new("", DataType::Float32, true)),
541 /// ],
542 /// // the merge lambda supported parameters
543 /// vec![
544 /// // the reduced value which is the field of the initial value
545 /// Arc::new(Field::new("ignored_name", DataType::Float32, true)),
546 /// ],
547 /// ])
548 /// );
549 /// ```
550 ///
551 /// For functions which lambda parameters depends on the output of other lambdas, or on their own lambda,
552 /// this can return [LambdaParametersProgress::Partial] until all dependencies are met. Note that for
553 /// lambda with cyclic dependencies, you likely want to use [HigherOrderUDFImpl::coerce_values_for_lambdas] too.
554 /// Take as an example a flexible array_reduce with the signature `(arr: [V], initial_value: I, (ACC, V) -> ACC, (ACC) -> O) -> O`.
555 /// It has a cyclic dependency in the merge lambda, and a dependency of the finish lambda in the merge lambda,
556 /// and only requires the initial value to be *coercible* to the output of the merge lambda, which is defined by
557 /// it's [HigherOrderUDFImpl::coerce_values_for_lambdas] implementation. The expression
558 ///
559 /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)`
560 ///
561 /// would result in this function being called as the following:
562 ///
563 /// ```ignore
564 /// let lambda_parameters = array_reduce.lambda_parameters(
565 /// 0,
566 /// &[
567 /// // the Field of the literal `[1.2, 2.1]`, the array being reduced
568 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))),
569 /// // the Field of the literal `0`, the initial value
570 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))),
571 /// // the Field of the output of the merge lambda, which is unknown at this point because it depends on
572 /// // the return this call
573 /// ValueOrLambda::Lambda(None),
574 /// // the Field of the output of the finish lambda, unknown for the same reason as above
575 /// ValueOrLambda::Lambda(None),
576 /// ])?;
577 ///
578 /// assert_eq!(
579 /// lambda_parameters,
580 /// LambdaParametersProgress::Partial(vec![
581 /// // the finish lambda supported parameters, regardless of how many are actually used
582 /// Some(vec![
583 /// // at step 0, use the field of the initial value
584 /// Arc::new(Field::new("ignored_name", DataType::Int32, true)),
585 /// // the array values being reduced
586 /// Arc::new(Field::new("", DataType::Float32, true)),
587 /// ]),
588 /// // the merge lambda supported parameters, unknown at this point due to dependency on the merge output
589 /// None,
590 /// ])
591 /// );
592 ///
593 /// let lambda_parameters = array_reduce.lambda_parameters(
594 /// 1,
595 /// &[
596 /// // the Field of the literal `[1.2, 2.1]`, the array being reduced
597 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))),
598 /// // the Field of the literal `0`, the initial value
599 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))),
600 /// // the Field of the output of the merge lambda, which could be inferred to be a Float32 based on the
601 /// // returned values of the previous step
602 /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))),
603 /// // the Field of the output of the finish lambda, which is unknown at this point because it depends
604 /// // on the return of this call
605 /// ValueOrLambda::Lambda(None),
606 /// ])?;
607 ///
608 /// assert_eq!(
609 /// lambda_parameters,
610 /// LambdaParametersProgress::Complete(vec![
611 /// // the finish lambda supported parameters, regardless of how many are actually used
612 /// vec![
613 /// // the finish lambda own output now used as it's accumulator
614 /// Arc::new(Field::new("ignored_name", DataType::Float32, true)),
615 /// // the array values being reduced
616 /// Arc::new(Field::new("", DataType::Float32, true)),
617 /// ],
618 /// // the merge lambda supported parameters, which is the output of the merge lambda,
619 /// vec![
620 /// // the output of the merge lambda
621 /// Arc::new(Field::new("", DataType::Float32, true)),
622 /// ],
623 /// ])
624 /// );
625 ///
626 /// let coerce_to = array_reduce.coerce_values_for_lambdas(&[
627 /// // the literal `[1.2, 2.1]` data type, the array being reduced
628 /// ValueOrLambda::Value(DataType::new_list(DataType::Float32, true)),
629 /// // the literal `0` data type, the initial value
630 /// ValueOrLambda::Value(DataType::Int32),
631 /// // the output data type of the merge lambda
632 /// ValueOrLambda::Lambda(DataType::Float32),
633 /// // the output data type of the finish lambda
634 /// ValueOrLambda::Lambda(DataType::Boolean),
635 /// ])?;
636 ///
637 /// assert_eq!(
638 /// coerce_to,
639 /// Some(vec![
640 /// // return the same type for the array being reduced
641 /// DataType::new_list(DataType::Float32, true),
642 /// // coerce the initial value to the output of the merge lambda
643 /// DataType::Float32,
644 /// ])
645 /// );
646 ///
647 /// ```
648 ///
649 /// Note this may also be called at step 0 with all lambda outputs already set, and in that case,
650 /// [LambdaParametersProgress::Complete] must be returned
651 ///
652 /// The implementation can assume that some other part of the code has coerced
653 /// the actual argument types to match [`Self::signature`], except the coercion defined by
654 /// [Self::coerce_values_for_lambdas].
655 ///
656 /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
657 /// [`HigherOrderFunction::lambda_parameters`]: crate::expr::HigherOrderFunction::lambda_parameters
658 fn lambda_parameters(
659 &self,
660 step: usize,
661 fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
662 ) -> Result<LambdaParametersProgress>;
663
664 /// Coerce value arguments of a function call to types that the function can evaluate also taking into
665 /// account the *output type of it's lambdas*. This differs from [HigherOrderUDFImpl::coerce_value_types]
666 /// that only has access to the type of it's value arguments because it's called before the output type
667 /// of lambdas are known.
668 ///
669 /// See the [type coercion module](crate::type_coercion)
670 /// documentation for more details on type coercion
671 ///
672 /// # Parameters
673 /// * `fields`: The argument types of the value arguments of this function, or the output type of lambdas
674 ///
675 /// # Return value
676 /// If `Some`, contains a Vec with the same number of [ValueOrLambda::Value] in `fields`.
677 /// DataFusion will `CAST` the function call arguments to these specific types. If `None`, no
678 /// coercion will be applied beyond the one defined by the function signature.
679 ///
680 /// For example, a flexible array_reduce implementation (see [Self::lambda_parameters] docs), when working
681 /// with the expression below, may want to coerce it's initial value argument, the *integer* `0`,
682 /// to match the output of it's merge function, which is a *float*:
683 ///
684 /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)`
685 fn coerce_values_for_lambdas(
686 &self,
687 _fields: &[ValueOrLambda<DataType, DataType>],
688 ) -> Result<Option<Vec<DataType>>> {
689 Ok(None)
690 }
691
692 /// What type will be returned by this function, given the arguments?
693 ///
694 /// The implementation can assume that some other part of the code has coerced
695 /// the actual argument types to match [`Self::signature`], including the coercion
696 /// defined by [Self::coerce_values_for_lambdas].
697 ///
698 /// # Example creating `Field`
699 ///
700 /// Note the name of the `Field` is ignored, except for structured types such as
701 /// `DataType::Struct`.
702 ///
703 /// ```rust
704 /// # use std::sync::Arc;
705 /// # use arrow::datatypes::{DataType, Field, FieldRef};
706 /// # use datafusion_common::Result;
707 /// # use datafusion_expr::HigherOrderReturnFieldArgs;
708 /// # struct Example{}
709 /// # impl Example {
710 /// fn return_field_from_args(&self, args: HigherOrderReturnFieldArgs) -> Result<FieldRef> {
711 /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
712 /// Ok(field)
713 /// }
714 /// # }
715 /// ```
716 fn return_field_from_args(
717 &self,
718 args: HigherOrderReturnFieldArgs,
719 ) -> Result<FieldRef>;
720
721 /// Whether List or LargeList arguments should have it's non-empty null
722 /// sublists cleaned with [remove_list_null_values] before invoking this function
723 ///
724 /// The default implementation always returns true and should only be implemented
725 /// if you want to handle non-empty null sublists yourself
726 ///
727 /// [remove_list_null_values]: datafusion_common::utils::remove_list_null_values
728 // todo: extend this to listview and maps when remove_list_null_values supports it
729 fn clear_null_values(&self) -> bool {
730 true
731 }
732
733 /// Invoke the function returning the appropriate result.
734 ///
735 /// # Performance
736 ///
737 /// For the best performance, the implementations should handle the common case
738 /// when one or more of their arguments are constant values (aka
739 /// [`ColumnarValue::Scalar`]).
740 ///
741 /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
742 /// to arrays, which will likely be simpler code, but be slower.
743 fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue>;
744
745 /// Returns true if some of this `exprs` subexpressions may not be evaluated
746 /// and thus any side effects (like divide by zero) may not be encountered.
747 ///
748 /// Setting this to true prevents certain optimizations such as common
749 /// subexpression elimination
750 ///
751 /// When overriding this function to return `true`, [HigherOrderUDFImpl::conditional_arguments] can also be
752 /// overridden to report more accurately which arguments are eagerly evaluated and which ones
753 /// lazily.
754 fn short_circuits(&self) -> bool {
755 false
756 }
757
758 /// Determines which of the arguments passed to *this higher-order function*
759 /// are evaluated eagerly and which may be evaluated lazily. Note that this
760 /// does *not* applies to the arguments that *lambda functions* pass to it's
761 /// body expression
762 ///
763 /// If this function returns `None`, all arguments are eagerly evaluated.
764 /// Returning `None` is a micro optimization that saves a needless `Vec`
765 /// allocation.
766 ///
767 /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
768 /// are the arguments that are always evaluated, and `lazy` are the
769 /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
770 /// in some cases).
771 ///
772 /// Implementations must ensure that the two returned `Vec`s are disjunct,
773 /// and that each argument from `args` is present in one the two `Vec`s.
774 ///
775 /// When overriding this function, [HigherOrderUDFImpl::short_circuits] must
776 /// be overridden to return `true`.
777 fn conditional_arguments<'a>(
778 &self,
779 args: &'a [Expr],
780 ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
781 if self.short_circuits() {
782 Some((vec![], args.iter().collect()))
783 } else {
784 None
785 }
786 }
787
788 /// Coerce value arguments of a function call to types that the function can evaluate.
789 /// Note that if you need to coerce values based on the output type of lambdas, you
790 /// must use [HigherOrderUDFImpl::coerce_values_for_lambdas], as this function is used before
791 /// the output type of lambdas are known
792 ///
793 /// See the [type coercion module](crate::type_coercion)
794 /// documentation for more details on type coercion
795 ///
796 /// For example, if your function requires a contiguous list argument, but the user calls
797 /// it like `my_func(c, v -> v+2)` (i.e. with `c` as a ListView), coerce_types can return `[DataType::List(..)]`
798 /// to ensure the argument is converted to a List
799 ///
800 /// # Parameters
801 /// * `arg_types`: The argument types of the value arguments of this function, excluding lambdas
802 ///
803 /// # Return value
804 /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
805 /// arguments to these specific types.
806 fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
807 not_impl_err!(
808 "Function {} does not implement coerce_value_types",
809 self.name()
810 )
811 }
812
/// Returns the documentation for this function.
///
/// Documentation can be accessed programmatically as well as generating
/// publicly facing documentation.
///
/// The default implementation returns `None` (no documentation).
fn documentation(&self) -> Option<&Documentation> {
    None
}
820}
821
/// Logical representation of a Higher Order User Defined Function.
///
/// A higher order function takes one or more lambda arguments in addition to
/// regular value arguments. This struct contains the information DataFusion
/// needs to plan and invoke functions you supply such as name, type signature,
/// return type, and actual implementation.
///
/// Cloning is cheap (it clones an [`Arc`]). Equality and hashing delegate to
/// the wrapped [`HigherOrderUDFImpl`].
#[derive(Debug, Clone)]
pub struct HigherOrderUDF {
    /// The shared implementation; the accessor and invocation methods of this
    /// type delegate to it.
    inner: Arc<dyn HigherOrderUDFImpl>,
}
832
impl PartialEq for HigherOrderUDF {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to the implementation's `DynEq`, so two wrappers are equal
        // when their wrapped implementations compare equal.
        self.inner.as_ref().dyn_eq(other.inner.as_ref())
    }
}
838
839impl PartialOrd for HigherOrderUDF {
840 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
841 let mut cmp = self.name().cmp(other.name());
842 if cmp == Ordering::Equal {
843 cmp = self.signature().partial_cmp(other.signature())?;
844 }
845 if cmp == Ordering::Equal {
846 cmp = self.aliases().partial_cmp(other.aliases())?;
847 }
848 // Contract for PartialOrd and PartialEq consistency requires that
849 // a == b if and only if partial_cmp(a, b) == Some(Equal).
850 if cmp == Ordering::Equal && self != other {
851 // Functions may have other properties besides name and signature
852 // that differentiate two instances (e.g. type, or arbitrary parameters).
853 // We cannot return Some(Equal) in such case.
854 return None;
855 }
856 debug_assert!(
857 cmp == Ordering::Equal || self != other,
858 "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
859 The functions compare as equal, but they are not equal based on general properties that \
860 the PartialOrd implementation observes,",
861 self.name(),
862 other.name()
863 );
864 Some(cmp)
865 }
866}
867
// `PartialEq` delegates to the wrapped implementation's `DynEq`; `Eq` assumes
// that comparison is reflexive — NOTE(review): confirm for all implementations.
impl Eq for HigherOrderUDF {}

impl Hash for HigherOrderUDF {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash the wrapped implementation, mirroring `PartialEq`, which also
        // delegates to it; equal values hash equally as long as the
        // implementation's `DynEq` and `DynHash` agree.
        self.inner.dyn_hash(state)
    }
}
875
876impl HigherOrderUDF {
877 /// Create a new `HigherOrderUDF` from a [`HigherOrderUDFImpl`] trait object.
878 ///
879 /// Note this is the same as using the `From` impl (`HigherOrderUDF::from`).
880 pub fn new_from_impl<F>(fun: F) -> HigherOrderUDF
881 where
882 F: HigherOrderUDFImpl + 'static,
883 {
884 Self::new_from_shared_impl(Arc::new(fun))
885 }
886
887 /// Create a new `HigherOrderUDF` from a shared [`HigherOrderUDFImpl`] trait object.
888 pub fn new_from_shared_impl(fun: Arc<dyn HigherOrderUDFImpl>) -> HigherOrderUDF {
889 Self { inner: fun }
890 }
891
892 /// Return the underlying [`HigherOrderUDFImpl`] trait object for this function.
893 pub fn inner(&self) -> &Arc<dyn HigherOrderUDFImpl> {
894 &self.inner
895 }
896
897 /// Adds additional names that can be used to invoke this function, in
898 /// addition to `name`.
899 ///
900 /// If you implement [`HigherOrderUDFImpl`] directly you should return aliases
901 /// directly.
902 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
903 Self::new_from_impl(AliasedHigherOrderUDFImpl::new(
904 Arc::clone(&self.inner),
905 aliases,
906 ))
907 }
908
909 /// Returns this function's name.
910 ///
911 /// See [`HigherOrderUDFImpl::name`] for more details.
912 pub fn name(&self) -> &str {
913 self.inner.name()
914 }
915
916 /// Returns the aliases for this function.
917 ///
918 /// See [`HigherOrderUDF::with_aliases`] for more details.
919 pub fn aliases(&self) -> &[String] {
920 self.inner.aliases()
921 }
922
923 /// Returns this function's schema_name.
924 ///
925 /// See [`HigherOrderUDFImpl::schema_name`] for more details.
926 pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
927 self.inner.schema_name(args)
928 }
929
930 /// Returns this function's [`HigherOrderSignature`].
931 pub fn signature(&self) -> &HigherOrderSignature {
932 self.inner.signature()
933 }
934
935 /// Returns the parameters of all lambdas of this function for the current step.
936 ///
937 /// See [`HigherOrderUDFImpl::lambda_parameters`] for more details.
938 pub fn lambda_parameters(
939 &self,
940 step: usize,
941 fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
942 ) -> Result<LambdaParametersProgress> {
943 self.inner.lambda_parameters(step, fields)
944 }
945
946 /// Coerce value arguments based on lambda output types.
947 ///
948 /// See [`HigherOrderUDFImpl::coerce_values_for_lambdas`] for more details.
949 pub fn coerce_values_for_lambdas(
950 &self,
951 fields: &[ValueOrLambda<DataType, DataType>],
952 ) -> Result<Option<Vec<DataType>>> {
953 self.inner.coerce_values_for_lambdas(fields)
954 }
955
956 /// Returns the return field of the function given its arguments.
957 ///
958 /// See [`HigherOrderUDFImpl::return_field_from_args`] for more details.
959 pub fn return_field_from_args(
960 &self,
961 args: HigherOrderReturnFieldArgs,
962 ) -> Result<FieldRef> {
963 self.inner.return_field_from_args(args)
964 }
965
966 /// Whether List or LargeList arguments should have non-empty null sublists
967 /// cleaned before invoking this function.
968 pub fn clear_null_values(&self) -> bool {
969 self.inner.clear_null_values()
970 }
971
972 /// Invoke the function returning the appropriate result.
973 ///
974 /// See [`HigherOrderUDFImpl::invoke_with_args`] for more details.
975 pub fn invoke_with_args(
976 &self,
977 args: HigherOrderFunctionArgs,
978 ) -> Result<ColumnarValue> {
979 self.inner.invoke_with_args(args)
980 }
981
982 /// Returns true if some of this function's subexpressions may not be evaluated.
983 ///
984 /// See [`HigherOrderUDFImpl::short_circuits`] for more details.
985 pub fn short_circuits(&self) -> bool {
986 self.inner.short_circuits()
987 }
988
989 /// Returns which arguments are evaluated eagerly vs lazily.
990 ///
991 /// See [`HigherOrderUDFImpl::conditional_arguments`] for more details.
992 pub fn conditional_arguments<'a>(
993 &self,
994 args: &'a [Expr],
995 ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
996 self.inner.conditional_arguments(args)
997 }
998
999 /// Coerce value arguments of a function call to types that the function can evaluate.
1000 ///
1001 /// See [`HigherOrderUDFImpl::coerce_value_types`] for more details.
1002 pub fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1003 self.inner.coerce_value_types(arg_types)
1004 }
1005
1006 /// Returns the documentation for this function, if any.
1007 pub fn documentation(&self) -> Option<&Documentation> {
1008 self.inner.documentation()
1009 }
1010}
1011
1012impl<F> From<F> for HigherOrderUDF
1013where
1014 F: HigherOrderUDFImpl + 'static,
1015{
1016 fn from(fun: F) -> Self {
1017 Self::new_from_impl(fun)
1018 }
1019}
1020
/// `HigherOrderUDFImpl` that adds aliases to the underlying function. It is
/// better to implement [`HigherOrderUDFImpl`], which supports aliases, directly
/// if possible.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedHigherOrderUDFImpl {
    /// The wrapped implementation; every trait method except `aliases`
    /// delegates to it. Wrapped in `UdfEq`, presumably so the derived
    /// `PartialEq`/`Eq`/`Hash` work on the trait object — TODO confirm.
    inner: UdfEq<Arc<dyn HigherOrderUDFImpl>>,
    /// The wrapped function's own aliases followed by the added ones.
    aliases: Vec<String>,
}
1029
1030impl AliasedHigherOrderUDFImpl {
1031 fn new(
1032 inner: Arc<dyn HigherOrderUDFImpl>,
1033 new_aliases: impl IntoIterator<Item = &'static str>,
1034 ) -> Self {
1035 let mut aliases = inner.aliases().to_vec();
1036 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1037 Self {
1038 inner: inner.into(),
1039 aliases,
1040 }
1041 }
1042}
1043
// Every method forwards to the wrapped implementation except `aliases`, which
// returns the combined alias list stored on this wrapper.
#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
impl HigherOrderUDFImpl for AliasedHigherOrderUDFImpl {
    fn name(&self) -> &str {
        self.inner.name()
    }

    fn aliases(&self) -> &[String] {
        // The only non-delegating method: inner aliases plus the added ones.
        &self.aliases
    }

    fn schema_name(&self, args: &[Expr]) -> Result<String> {
        self.inner.schema_name(args)
    }

    fn signature(&self) -> &HigherOrderSignature {
        self.inner.signature()
    }

    fn lambda_parameters(
        &self,
        step: usize,
        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
    ) -> Result<LambdaParametersProgress> {
        self.inner.lambda_parameters(step, fields)
    }

    fn coerce_values_for_lambdas(
        &self,
        fields: &[ValueOrLambda<DataType, DataType>],
    ) -> Result<Option<Vec<DataType>>> {
        self.inner.coerce_values_for_lambdas(fields)
    }

    fn return_field_from_args(
        &self,
        args: HigherOrderReturnFieldArgs,
    ) -> Result<FieldRef> {
        self.inner.return_field_from_args(args)
    }

    fn clear_null_values(&self) -> bool {
        self.inner.clear_null_values()
    }

    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
        self.inner.invoke_with_args(args)
    }

    fn short_circuits(&self) -> bool {
        self.inner.short_circuits()
    }

    fn conditional_arguments<'a>(
        &self,
        args: &'a [Expr],
    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
        self.inner.conditional_arguments(args)
    }

    fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        self.inner.coerce_value_types(arg_types)
    }

    fn documentation(&self) -> Option<&Documentation> {
        self.inner.documentation()
    }
}
1111
1112pub(crate) fn resolve_lambda_variables(
1113 expr: Expr,
1114 schema: &DFSchema,
1115 // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ]
1116 vars: &mut HashMap<String, Vec<FieldRef>>,
1117) -> Result<Transformed<Expr>> {
1118 expr.transform_down(|expr| match expr {
1119 Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => {
1120 // not inlined to reduce nesting
1121 resolve_higher_order_function(func, args, schema, vars)
1122 }
1123 Expr::LambdaVariable(mut var) => {
1124 let field_stack = vars.get(&var.name).ok_or_else(|| {
1125 plan_datafusion_err!(
1126 "missing field of lambda variable {} while resolving",
1127 var.name
1128 )
1129 })?;
1130
1131 let field = field_stack.last().ok_or_else(|| {
1132 internal_datafusion_err!("every entry should have at least one field")
1133 })?;
1134
1135 let field = Arc::clone(field).renamed(&var.name);
1136
1137 let transformed = var.field.as_ref().is_none_or(|old| old != &field);
1138
1139 var.field = Some(field);
1140
1141 Ok(Transformed::new_transformed(
1142 Expr::LambdaVariable(var),
1143 transformed,
1144 ))
1145 }
1146 _ => Ok(Transformed::no(expr)),
1147 })
1148}
1149
1150fn resolve_higher_order_function(
1151 func: Arc<HigherOrderUDF>,
1152 args: Vec<Expr>,
1153 schema: &DFSchema,
1154 // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ]
1155 vars: &mut HashMap<String, Vec<FieldRef>>,
1156) -> Result<Transformed<Expr>> {
1157 let args = if !vars.is_empty() {
1158 /* if this is a nested lambda, we must resolve non-lambda args before invoking
1159 lambda_parameters because it will invoke ExprSchemable::to_field for every
1160 non-lambda parameter, and if one them contains a lambda variable, it will fail
1161 due to it being unresolved. Example query:
1162
1163 array_transform([[1, 2]], a -> array_transform(a, b -> b+1))
1164
1165 the nested array_transform's lambda_parameters will call Lambdavariable::to_field
1166 on it's first argument, the variable `a`, which must be resolved
1167 */
1168 args.map_elements(|arg| match arg {
1169 Expr::Lambda(_) => Ok(Transformed::no(arg)),
1170 _ => resolve_lambda_variables(arg, schema, vars),
1171 })?
1172 } else {
1173 Transformed::no(args)
1174 };
1175
1176 let transformed = args.transformed;
1177 let mut args = args.data;
1178
1179 let current_fields = args
1180 .iter()
1181 .map(|e| match e {
1182 Expr::Lambda(_lambda_function) => Ok(ValueOrLambda::Lambda(None)),
1183 _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)),
1184 })
1185 .collect::<Result<Vec<_>>>()?;
1186
1187 // coerce fields because coercion may alter the lambda parameters
1188 let mut fields = value_fields_with_higher_order_udf(¤t_fields, func.as_ref())?;
1189
1190 let num_lambdas = args.iter().filter(|a| matches!(a, Expr::Lambda(_))).count();
1191
1192 let mut step = 0;
1193
1194 let lambda_params = loop {
1195 match func.lambda_parameters(step, &fields)? {
1196 LambdaParametersProgress::Partial(params) => {
1197 let mut params = params.into_iter();
1198
1199 if params.len() != num_lambdas {
1200 return plan_err!(
1201 "{} lambda_parameters returned {} lambdas but {num_lambdas} expected",
1202 func.name(),
1203 params.len()
1204 );
1205 }
1206
1207 for (arg, field) in std::iter::zip(&mut args, &mut fields) {
1208 match (arg, field) {
1209 (Expr::Lambda(lambda), ValueOrLambda::Lambda(field)) => {
1210 let params = params.next().ok_or_else(|| {
1211 internal_datafusion_err!(
1212 "params len should have been checked above"
1213 )
1214 })?;
1215
1216 if let Some(params) = params {
1217 for (name, field) in
1218 std::iter::zip(&lambda.params, params)
1219 {
1220 vars.entry_ref(name)
1221 .or_default()
1222 .push(field.renamed(name.as_str()));
1223 }
1224
1225 let body_with_vars = resolve_lambda_variables(
1226 mem::take(lambda.body.as_mut()),
1227 schema,
1228 vars,
1229 )?;
1230
1231 remove_scope(vars, &lambda.params)?;
1232
1233 *field = Some(body_with_vars.data.to_field(schema)?.1);
1234 *lambda.body = body_with_vars.data;
1235 }
1236 }
1237 (_, ValueOrLambda::Lambda(_)) => {
1238 return internal_err!(
1239 "value_fields_with_higher_order_udf returned a value for a lambda argument"
1240 );
1241 }
1242 (Expr::Lambda(_), ValueOrLambda::Value(_)) => {
1243 return internal_err!(
1244 "value_fields_with_higher_order_udf returned a lambda for a value argument"
1245 );
1246 }
1247 (_, ValueOrLambda::Value(_)) => {} // nothing to do
1248 }
1249 }
1250 }
1251 LambdaParametersProgress::Complete(params) => break params,
1252 }
1253
1254 let limit = func.signature().lambda_parameters_max_iterations;
1255
1256 step += 1;
1257
1258 if step > limit {
1259 return plan_err!(
1260 "{} lambda_parameters called {limit} times without completion",
1261 func.name()
1262 );
1263 }
1264 };
1265
1266 let mut lambda_params = lambda_params.into_iter();
1267
1268 if num_lambdas != lambda_params.len() {
1269 return plan_err!(
1270 "{} lambda_parameters returned {} values for {num_lambdas} lambdas",
1271 func.name(),
1272 lambda_params.len()
1273 );
1274 }
1275
1276 let args = args.map_elements(|arg| match arg {
1277 Expr::Lambda(mut lambda) => {
1278 let lambda_params = lambda_params.next().ok_or_else(|| {
1279 internal_datafusion_err!(
1280 "lambda_params len should have been checked above"
1281 )
1282 })?;
1283
1284 if lambda.params.len() > lambda_params.len() {
1285 return plan_err!(
1286 "{} lambda defined {} params ({}), but only {} supported",
1287 func.name(),
1288 lambda.params.len(),
1289 display_comma_separated(&lambda.params),
1290 lambda_params.len()
1291 );
1292 }
1293
1294 if !all_unique(&lambda.params) {
1295 return plan_err!(
1296 "lambda params must be unique, got ({})",
1297 lambda.params.join(", ")
1298 );
1299 }
1300
1301 for (param, field) in std::iter::zip(&lambda.params, lambda_params) {
1302 vars.entry_ref(param)
1303 .or_default()
1304 .push(field.renamed(param.as_str()));
1305 }
1306
1307 let transformed =
1308 resolve_lambda_variables(mem::take(lambda.body.as_mut()), schema, vars)?;
1309
1310 *lambda.body = transformed.data;
1311
1312 remove_scope(vars, &lambda.params)?;
1313
1314 Ok(Transformed::new(
1315 Expr::Lambda(lambda),
1316 transformed.transformed,
1317 TreeNodeRecursion::Jump,
1318 ))
1319 }
1320 arg => Ok(Transformed::no(arg)), // resolved at the start of the function
1321 })?;
1322
1323 Ok(Transformed::new(
1324 Expr::HigherOrderFunction(HigherOrderFunction::new(func, args.data)),
1325 transformed || args.transformed,
1326 TreeNodeRecursion::Jump,
1327 ))
1328}
1329
1330fn remove_scope(
1331 vars: &mut HashMap<String, Vec<FieldRef>>,
1332 scope: &[String],
1333) -> Result<()> {
1334 for param in scope {
1335 match vars.entry_ref(param) {
1336 EntryRef::Occupied(mut v) => {
1337 if v.get().len() == 1 {
1338 v.remove();
1339 } else {
1340 v.get_mut().pop().ok_or_else(|| {
1341 internal_datafusion_err!(
1342 "every entry should have at least one field"
1343 )
1344 })?;
1345 }
1346 }
1347 EntryRef::Vacant(_v) => {
1348 return internal_err!("no empty value should be in the map");
1349 }
1350 }
1351 }
1352
1353 Ok(())
1354}
1355
/// Returns `true` if no two entries of `params` are equal.
///
/// Lists of up to two names are compared directly; longer lists use a hash set.
fn all_unique(params: &[String]) -> bool {
    if let [first, second] = params {
        return first != second;
    }

    if params.len() < 2 {
        return true;
    }

    let mut seen = HashSet::with_capacity(params.len());
    !params.iter().any(|param| !seen.insert(param.as_str()))
}
1367
#[cfg(test)]
mod tests {
    use super::*;
    use std::hash::DefaultHasher;
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, FieldRef, Schema};
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_expr_common::signature::Volatility;

    use crate::{
        Expr, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl,
        LambdaParametersProgress, ValueOrLambda, col,
        expr::{HigherOrderFunction, LambdaVariable},
        lambda, lambda_var, lit,
    };

    /// Minimal implementation for equality/hash/ordering tests. `field` is not
    /// read by any method; it only makes otherwise identical instances differ
    /// in the derived `PartialEq`/`Hash`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestHigherOrderUDF {
        name: &'static str,
        field: &'static str,
        signature: HigherOrderSignature,
    }
    impl HigherOrderUDFImpl for TestHigherOrderUDF {
        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &HigherOrderSignature {
            &self.signature
        }

        fn lambda_parameters(
            &self,
            _step: usize,
            _fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
        ) -> Result<LambdaParametersProgress> {
            unimplemented!()
        }

        fn return_field_from_args(
            &self,
            _args: HigherOrderReturnFieldArgs,
        ) -> Result<FieldRef> {
            unimplemented!()
        }

        fn invoke_with_args(
            &self,
            _args: HigherOrderFunctionArgs,
        ) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
    // must be consistent, so they are tested together.
    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // A parameterized function
        let f = test_func("foo", "a");

        // Same like `f`, different instance
        let f2 = test_func("foo", "a");
        assert_eq!(&f, &f2);
        assert_eq!(hash(&f), hash(&f2));
        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));

        // Different parameter
        let b = test_func("foo", "b");
        assert_ne!(&f, &b);
        assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
        assert_eq!(f.partial_cmp(&b), None);

        // Different name
        let o = test_func("other", "a");
        assert_ne!(&f, &o);
        assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
        assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));

        // Different name and parameter
        assert_ne!(&b, &o);
        assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
        assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
    }

    /// Builds a shared `TestHigherOrderUDF` with the given name and parameter.
    fn test_func(name: &'static str, parameter: &'static str) -> Arc<HigherOrderUDF> {
        Arc::new(HigherOrderUDF::new_from_impl(TestHigherOrderUDF {
            name,
            field: parameter,
            signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
        }))
    }

    /// Hashes `value` with the standard library's default hasher.
    fn hash<T: Hash>(value: &T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }

    /// Simplified `array_reduce(list, initial_value, merge, finish)` used to
    /// exercise multi-step `lambda_parameters` resolution: the merge
    /// accumulator type is only known after resolving the merge body once.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockArrayReduce {
        signature: HigherOrderSignature,
    }

    impl HigherOrderUDFImpl for MockArrayReduce {
        fn name(&self) -> &str {
            "array_reduce"
        }

        fn aliases(&self) -> &[String] {
            &[]
        }

        fn signature(&self) -> &HigherOrderSignature {
            &self.signature
        }

        fn lambda_parameters(
            &self,
            step: usize,
            fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
        ) -> Result<LambdaParametersProgress> {
            // optional finish not supported for simplicity
            let [
                ValueOrLambda::Value(list),
                ValueOrLambda::Value(initial_value),
                ValueOrLambda::Lambda(merge),
                ValueOrLambda::Lambda(_finish),
            ] = fields
            else {
                unreachable!()
            };

            let list_field = match list.data_type() {
                DataType::List(field) => field,
                _ => unreachable!(),
            };

            Ok(match (step, merge) {
                (0, None) => {
                    // at the first step, we use the initial_value as merge accumulator,
                    // and return None for finish since we don't know the output of merge
                    LambdaParametersProgress::Partial(vec![
                        // merge
                        Some(vec![Arc::clone(initial_value), Arc::clone(list_field)]),
                        // finish
                        None,
                    ])
                }
                (1, Some(accumulator)) | (0, Some(accumulator)) => {
                    // now we can use the merge output as its accumulator and
                    // as the finish parameter
                    LambdaParametersProgress::Complete(vec![
                        // merge
                        vec![Arc::clone(accumulator), Arc::clone(list_field)],
                        // finish
                        vec![Arc::clone(accumulator)],
                    ])
                }
                (1, None) => {
                    unreachable!()
                }
                _ => unreachable!(),
            })
        }

        fn return_field_from_args(
            &self,
            args: HigherOrderReturnFieldArgs,
        ) -> Result<FieldRef> {
            // optional finish not supported for simplicity
            let [
                ValueOrLambda::Value(_list),
                ValueOrLambda::Value(_initial_value),
                ValueOrLambda::Lambda(_merge),
                ValueOrLambda::Lambda(finish),
            ] = args.arg_fields
            else {
                unreachable!()
            };

            Ok(Arc::clone(finish))
        }

        fn invoke_with_args(
            &self,
            _args: HigherOrderFunctionArgs,
        ) -> Result<ColumnarValue> {
            unreachable!()
        }
    }

    #[test]
    fn test_resolve_lambda_variables() {
        let schema = DFSchema::try_from(Schema::new(vec![Field::new(
            "c",
            DataType::new_list(DataType::new_list(DataType::Int32, true), true),
            true,
        )]))
        .unwrap();

        let func = Arc::new(HigherOrderUDF::new_from_impl(MockArrayReduce {
            signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
        }));

        /*
        array_reduce(
            c,
            0,
            (acc1, v) -> acc1 + array_reduce(
                v,
                0,
                (acc2, v) -> acc2 + acc1 + v,
                reduced -> reduced * 2.0
            ),
            reduced -> reduced * 2
        )
        */
        let expr = Expr::HigherOrderFunction(HigherOrderFunction::new(
            Arc::clone(&func),
            vec![
                col("c"),
                lit(0),
                lambda(
                    ["acc1", "v"],
                    lambda_var("acc1")
                        + Expr::HigherOrderFunction(HigherOrderFunction::new(
                            Arc::clone(&func),
                            vec![
                                lambda_var("v"),
                                lit(0),
                                lambda(
                                    ["acc2", "v"],
                                    lambda_var("acc2")
                                        + lambda_var("acc1")
                                        + lambda_var("v"),
                                ),
                                lambda(["reduced"], lambda_var("reduced") * lit(2.0)),
                            ],
                        )),
                ),
                lambda(["reduced"], lambda_var("reduced") * lit(2)),
            ],
        ));

        let resolved_expr = expr.resolve_lambda_variables(&schema).unwrap().data;

        /*
        array_reduce(
            c@[[Int32]],
            0@Int64,
            (acc1@Float64, v@[Int32]) -> acc1@Float64 + array_reduce(
                v@[Int32],
                0@Int64,
                (acc2@Float64, v@Int32) -> acc2@Float64 + acc1@Float64 + v@Int32,
                reduced@Float64 -> reduced@Float64 * 2.0@Float64
            ),
            reduced@Float64 -> reduced@Float64 * 2@Int64
        )
        */
        let expected = Expr::HigherOrderFunction(HigherOrderFunction::new(
            Arc::clone(&func),
            vec![
                col("c"),
                lit(0),
                lambda(
                    ["acc1", "v"],
                    resolved_lambda_var("acc1", DataType::Float64, true)
                        + Expr::HigherOrderFunction(HigherOrderFunction::new(
                            Arc::clone(&func),
                            vec![
                                resolved_lambda_var(
                                    "v",
                                    DataType::new_list(DataType::Int32, true),
                                    true,
                                ),
                                lit(0),
                                lambda(
                                    ["acc2", "v"],
                                    resolved_lambda_var("acc2", DataType::Float64, true)
                                        + resolved_lambda_var(
                                            "acc1",
                                            DataType::Float64,
                                            true,
                                        )
                                        + resolved_lambda_var("v", DataType::Int32, true),
                                ),
                                lambda(
                                    ["reduced"],
                                    resolved_lambda_var(
                                        "reduced",
                                        DataType::Float64,
                                        true,
                                    ) * lit(2.0),
                                ),
                            ],
                        )),
                ),
                lambda(
                    ["reduced"],
                    resolved_lambda_var("reduced", DataType::Float64, true) * lit(2),
                ),
            ],
        ));

        assert_eq!(resolved_expr, expected);
    }

    /// Builds a lambda variable already resolved to a field named `name`.
    fn resolved_lambda_var(name: &str, dt: DataType, nullable: bool) -> Expr {
        Expr::LambdaVariable(LambdaVariable::new(
            name.into(),
            Some(Arc::new(Field::new(name, dt, nullable))),
        ))
    }
}