1use std::fmt::{self, Debug, Formatter};
33use std::hash::{Hash, Hasher};
34use std::sync::Arc;
35
36use crate::PhysicalExpr;
37use crate::expressions::{LambdaExpr, Literal};
38
39use arrow::array::{Array, RecordBatch};
40use arrow::datatypes::{DataType, FieldRef, Schema};
41use datafusion_common::config::{ConfigEntry, ConfigOptions};
42use datafusion_common::datatype::FieldExt;
43use datafusion_common::utils::remove_list_null_values;
44use datafusion_common::{
45 Result, ScalarValue, exec_err, internal_datafusion_err, internal_err,
46 plan_datafusion_err, plan_err,
47};
48use datafusion_expr::type_coercion::functions::value_fields_with_higher_order_udf;
49use datafusion_expr::{
50 ColumnarValue, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderUDF,
51 LambdaArgument, LambdaParametersProgress, ValueOrLambda, Volatility, expr_vec_fmt,
52};
53
54enum ArgSlot {
61 Value,
63 Lambda(Arc<LambdaExpr>),
67}
68
69pub struct HigherOrderFunctionExpr {
71 fun: Arc<HigherOrderUDF>,
73 name: String,
75 args: Vec<Arc<dyn PhysicalExpr>>,
91 slots: Vec<ArgSlot>,
95 return_field: FieldRef,
100 config_options: Arc<ConfigOptions>,
102}
103
104impl Debug for HigherOrderFunctionExpr {
105 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
106 let lambda_positions: Vec<_> = self
107 .slots
108 .iter()
109 .enumerate()
110 .filter_map(|(i, slot)| matches!(slot, ArgSlot::Lambda(_)).then_some(i))
111 .collect();
112 f.debug_struct("HigherOrderFunctionExpr")
113 .field("fun", &"<FUNC>")
114 .field("name", &self.name)
115 .field("args", &self.args)
116 .field("lambda_positions", &lambda_positions)
117 .field("return_field", &self.return_field)
118 .finish()
119 }
120}
121
122impl HigherOrderFunctionExpr {
123 pub fn try_new_with_schema(
128 fun: Arc<HigherOrderUDF>,
129 args: Vec<Arc<dyn PhysicalExpr>>,
130 schema: &Schema,
131 config_options: Arc<ConfigOptions>,
132 ) -> Result<Self> {
133 let name = fun.name().to_string();
134 let mut slots = Vec::with_capacity(args.len());
135 let arg_fields = args
136 .iter()
137 .map(|e| match e.downcast_ref::<LambdaExpr>() {
138 Some(lambda) => {
139 slots.push(ArgSlot::Lambda(Arc::new(lambda.clone())));
140 Ok(ValueOrLambda::Lambda(lambda.body().return_field(schema)?))
141 }
142 None => {
143 slots.push(ArgSlot::Value);
144 Ok(ValueOrLambda::Value(e.return_field(schema)?))
145 }
146 })
147 .collect::<Result<Vec<_>>>()?;
148
149 value_fields_with_higher_order_udf(&arg_fields, fun.as_ref())?;
151
152 let arguments = args
153 .iter()
154 .map(|e| e.downcast_ref::<Literal>().map(|literal| literal.value()))
155 .collect::<Vec<_>>();
156
157 let ret_args = HigherOrderReturnFieldArgs {
158 arg_fields: &arg_fields,
159 scalar_arguments: &arguments,
160 };
161
162 let return_field = fun.return_field_from_args(ret_args)?;
163
164 Ok(Self {
165 fun,
166 name,
167 args,
168 slots,
169 return_field,
170 config_options,
171 })
172 }
173
174 pub fn fun(&self) -> &HigherOrderUDF {
176 self.fun.as_ref()
177 }
178
179 pub fn name(&self) -> &str {
181 &self.name
182 }
183
184 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
186 &self.args
187 }
188
189 pub fn return_type(&self) -> &DataType {
191 self.return_field.data_type()
192 }
193
194 pub fn nullable(&self) -> bool {
195 self.return_field.is_nullable()
196 }
197
198 pub fn config_options(&self) -> &ConfigOptions {
199 &self.config_options
200 }
201
202 fn resolve_lambda_parameters(
206 &self,
207 fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
208 ) -> Result<Vec<Vec<FieldRef>>> {
209 let num_lambdas = self
210 .slots
211 .iter()
212 .filter(|s| matches!(s, ArgSlot::Lambda(_)))
213 .count();
214 if num_lambdas == 0 {
215 return Ok(Vec::new());
216 }
217 match self.fun().lambda_parameters(0, fields)? {
218 LambdaParametersProgress::Partial(_) => plan_err!(
219 "{} lambda_parameters returned a partial result when the return type of all it's lambdas were provided",
220 self.name()
221 ),
222 LambdaParametersProgress::Complete(items) => {
223 if items.len() < num_lambdas {
228 return exec_err!(
229 "{} invocation defined {num_lambdas} but lambda_parameters returned only {}",
230 self.name(),
231 items.len()
232 );
233 }
234 Ok(items)
235 }
236 }
237 }
238}
239
240impl fmt::Display for HigherOrderFunctionExpr {
241 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
242 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
243 }
244}
245
246impl PartialEq for HigherOrderFunctionExpr {
247 fn eq(&self, o: &Self) -> bool {
248 if std::ptr::eq(self, o) {
249 return true;
251 }
252 let Self {
255 fun,
256 name,
257 args,
258 slots: _,
259 return_field,
260 config_options,
261 } = self;
262 fun.eq(&o.fun)
263 && name.eq(&o.name)
264 && args.eq(&o.args)
265 && return_field.eq(&o.return_field)
266 && (Arc::ptr_eq(config_options, &o.config_options)
267 || sorted_config_entries(config_options)
268 == sorted_config_entries(&o.config_options))
269 }
270}
271impl Eq for HigherOrderFunctionExpr {}
272impl Hash for HigherOrderFunctionExpr {
273 fn hash<H: Hasher>(&self, state: &mut H) {
274 let Self {
275 fun,
276 name,
277 args,
278 slots: _,
279 return_field,
280 config_options: _, } = self;
282 fun.hash(state);
283 name.hash(state);
284 args.hash(state);
285 return_field.hash(state);
286 }
287}
288
289fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
290 let mut entries = config_options.entries();
291 entries.sort_by(|l, r| l.key.cmp(&r.key));
292 entries
293}
294
295impl PhysicalExpr for HigherOrderFunctionExpr {
296 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
297 let mut arg_fields = Vec::with_capacity(self.args.len());
298 let mut fields = Vec::with_capacity(self.args.len());
299 for (arg, slot) in self.args.iter().zip(&self.slots) {
300 match slot {
301 ArgSlot::Lambda(lambda) => {
302 let field = lambda.body().return_field(batch.schema_ref())?;
303 arg_fields.push(ValueOrLambda::Lambda(Arc::clone(&field)));
304 fields.push(ValueOrLambda::Lambda(Some(field)));
305 }
306 ArgSlot::Value => {
307 let field = arg.return_field(batch.schema_ref())?;
308 arg_fields.push(ValueOrLambda::Value(Arc::clone(&field)));
309 fields.push(ValueOrLambda::Value(field));
310 }
311 }
312 }
313
314 let mut lambda_parameters = self.resolve_lambda_parameters(&fields)?.into_iter();
315
316 let args = self
317 .args
318 .iter()
319 .zip(&self.slots)
320 .map(|(arg, slot)| match slot {
321 ArgSlot::Lambda(lambda) => {
322 let lambda_params = lambda_parameters.next().ok_or_else(|| {
323 internal_datafusion_err!(
324 "params len should have been checked above"
325 )
326 })?;
327
328 if lambda.params().len() > lambda_params.len() {
329 return exec_err!(
330 "lambda defined {} params but higher-order function support only {}",
331 lambda.params().len(),
332 lambda_params.len()
333 );
334 }
335
336 let params = std::iter::zip(lambda.params(), lambda_params)
337 .map(|(name, param)| param.renamed(name.as_str()))
338 .collect();
339
340 let projection = lambda
342 .projection()
343 .iter()
344 .copied()
345 .filter(|i| *i < batch.num_columns())
346 .collect::<Vec<_>>();
347
348 Ok(ValueOrLambda::Lambda(LambdaArgument::new(
349 params,
350 Arc::clone(lambda.projected_body()),
351 if projection.is_empty() {
352 None
353 } else {
354 Some(batch.project(&projection)?)
355 },
356 )))
357 }
358 ArgSlot::Value => {
359 let value = arg.evaluate(batch)?;
360
361 let value = if self.fun.clear_null_values()
362 && matches!(
363 value.data_type(),
364 DataType::List(_) | DataType::LargeList(_)
365 )
366 {
367 let arr = value.into_array(batch.num_rows())?;
368 if arr.null_count() == 0 {
369 ColumnarValue::Array(arr)
370 } else {
371 ColumnarValue::Array(remove_list_null_values(&arr)?)
372 }
373 } else {
374 value
375 };
376
377 Ok(ValueOrLambda::Value(value))
378 }
379 })
380 .collect::<Result<Vec<_>>>()?;
381
382 let input_empty = args.is_empty();
383 let input_all_scalar = args
384 .iter()
385 .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_))));
386
387 let output = self.fun.invoke_with_args(HigherOrderFunctionArgs {
389 args,
390 arg_fields,
391 number_rows: batch.num_rows(),
392 return_field: Arc::clone(&self.return_field),
393 config_options: Arc::clone(&self.config_options),
394 })?;
395
396 if let ColumnarValue::Array(array) = &output
397 && array.len() != batch.num_rows()
398 {
399 let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar;
402 return if preserve_scalar {
403 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
404 } else {
405 internal_err!(
406 "higher-order function {} returned a different number of rows than expected. Expected: {}, Got: {}",
407 self.name,
408 batch.num_rows(),
409 array.len()
410 )
411 };
412 }
413 Ok(output)
414 }
415
416 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
417 Ok(Arc::clone(&self.return_field))
418 }
419
420 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
421 self.args.iter().collect()
422 }
423
424 fn with_new_children(
425 self: Arc<Self>,
426 children: Vec<Arc<dyn PhysicalExpr>>,
427 ) -> Result<Arc<dyn PhysicalExpr>> {
428 if children.len() != self.args.len() {
429 return internal_err!(
430 "HigherOrderFunctionExpr expects exactly {} child, got {}",
431 self.args.len(),
432 children.len()
433 );
434 }
435
436 let mut new_slots = Vec::with_capacity(children.len());
439 for (i, child) in children.iter().enumerate() {
440 match &self.slots[i] {
441 ArgSlot::Lambda(_) => {
442 let lambda = wrapped_lambda(child).ok_or_else(|| {
443 plan_datafusion_err!(
444 "{} unable to unwrap lambda from {} at position {i}",
445 &children[i],
446 self.name()
447 )
448 })?;
449 new_slots.push(ArgSlot::Lambda(Arc::new(lambda.clone())));
450 }
451 ArgSlot::Value => {
452 if child.is::<LambdaExpr>() {
453 return plan_err!(
454 "{} received a lambda via with_new_children at position {i} that wasn't a lambda before",
455 self.name()
456 );
457 }
458 new_slots.push(ArgSlot::Value);
459 }
460 }
461 }
462
463 Ok(Arc::new(HigherOrderFunctionExpr {
464 name: self.name.clone(),
465 fun: Arc::clone(&self.fun),
466 args: children,
467 slots: new_slots,
468 return_field: Arc::clone(&self.return_field),
469 config_options: Arc::clone(&self.config_options),
470 }))
471 }
472
473 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
474 write!(f, "{}(", self.name)?;
475 for (i, expr) in self.args.iter().enumerate() {
476 if i > 0 {
477 write!(f, ", ")?;
478 }
479 expr.fmt_sql(f)?;
480 }
481 write!(f, ")")
482 }
483
484 fn is_volatile_node(&self) -> bool {
485 self.fun.signature().volatility == Volatility::Volatile
486 }
487}
488
489fn wrapped_lambda(expr: &Arc<dyn PhysicalExpr>) -> Option<&LambdaExpr> {
490 let mut current = expr;
491
492 loop {
493 if let Some(lambda) = current.downcast_ref::<LambdaExpr>() {
494 return Some(lambda);
495 } else if current.is::<HigherOrderFunctionExpr>() {
496 return None;
497 }
498
499 match current.children().as_slice() {
500 [single_child] => current = *single_child,
501 _ => return None,
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::HigherOrderFunctionExpr;
    use crate::expressions::Column;
    use crate::expressions::NoOp;
    use crate::expressions::lambda;
    use crate::expressions::not;
    use arrow::array::NullArray;
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_common::assert_contains;
    use datafusion_expr::{
        HigherOrderFunctionArgs, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl,
    };
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use datafusion_physical_expr_common::physical_expr::is_volatile;

    /// Minimal higher-order UDF used to drive `HigherOrderFunctionExpr` in
    /// these tests. Only its signature (and hence volatility) is configurable.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockHigherOrderUDF {
        signature: HigherOrderSignature,
    }

    impl HigherOrderUDFImpl for MockHigherOrderUDF {
        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &HigherOrderSignature {
            &self.signature
        }

        /// Always reports exactly one lambda taking a single nullable
        /// `Null`-typed parameter, whatever the argument fields are.
        fn lambda_parameters(
            &self,
            _step: usize,
            _fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
        ) -> Result<LambdaParametersProgress> {
            Ok(LambdaParametersProgress::Complete(vec![vec![Arc::new(
                Field::new("", DataType::Null, true),
            )]]))
        }

        /// Returns the field of the first argument (lambda or value)
        /// unchanged. Indexes `arg_fields[0]`, so it panics with no arguments.
        fn return_field_from_args(
            &self,
            args: HigherOrderReturnFieldArgs,
        ) -> Result<FieldRef> {
            match &args.arg_fields[0] {
                ValueOrLambda::Lambda(field) | ValueOrLambda::Value(field) => {
                    Ok(Arc::clone(field))
                }
            }
        }

        /// Behaviour depends on the first argument:
        /// - a value argument is returned as a clone;
        /// - a lambda argument is evaluated with a single parameter source
        ///   producing a `NullArray` of `number_rows` rows.
        fn invoke_with_args(
            &self,
            args: HigherOrderFunctionArgs,
        ) -> Result<ColumnarValue> {
            match &args.args[0] {
                ValueOrLambda::Lambda(lambda) => lambda.evaluate(
                    &[&|| Ok(Arc::new(NullArray::new(args.number_rows)))],
                    |arrays| Ok(arrays.to_vec()),
                ),
                ValueOrLambda::Value(value) => Ok(value.clone()),
            }
        }
    }

    /// `is_volatile_node` (and the tree-wide `is_volatile`) must follow the
    /// volatility declared in the UDF's signature.
    #[test]
    fn test_higher_order_function_volatile_node() {
        // Same mock, declared volatile.
        let volatile_udf = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
            signature: HigherOrderSignature::variadic_any(Volatility::Volatile),
        }));

        // Same mock, declared stable.
        let stable_udf = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
            signature: HigherOrderSignature::variadic_any(Volatility::Stable),
        }));

        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        // Volatile signature => volatile node and volatile tree.
        let volatile_expr = HigherOrderFunctionExpr::try_new_with_schema(
            volatile_udf,
            args.clone(),
            &schema,
            Arc::clone(&config_options),
        )
        .unwrap();

        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        // Stable signature => neither the node nor the tree is volatile.
        let stable_expr = HigherOrderFunctionExpr::try_new_with_schema(
            stable_udf,
            args,
            &schema,
            config_options,
        )
        .unwrap();

        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }

    /// A lambda that `with_new_children` receives wrapped in another
    /// single-child expression (here `NOT`) is still recognised as the lambda
    /// argument. Evaluation uses the lambda itself, so the result is the
    /// literal body value, not its negation.
    #[test]
    fn test_higher_order_function_wrapped_lambda() {
        let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
            signature: HigherOrderSignature::variadic_any(Volatility::Stable),
        }));

        let expected = ScalarValue::Int32(Some(42));

        let hof = HigherOrderFunctionExpr::try_new_with_schema(
            fun,
            vec![lambda(["a"], Arc::new(Literal::new(expected.clone()))).unwrap()],
            &Schema::empty(),
            Arc::new(ConfigOptions::new()),
        )
        .unwrap();

        // Wrap the existing lambda child in NOT and rebuild the expression.
        let new_children = vec![not(Arc::clone(&hof.args[0])).unwrap()];
        let wrapped = Arc::new(hof).with_new_children(new_children).unwrap();

        let result = wrapped
            .evaluate(
                &RecordBatch::try_new_with_options(
                    Arc::new(Schema::empty()),
                    vec![],
                    &RecordBatchOptions::new().with_row_count(Some(0)),
                )
                .unwrap(),
            )
            .unwrap();

        let ColumnarValue::Scalar(result) = result else {
            unreachable!()
        };

        assert_eq!(result, expected);
    }

    /// A lambda that is already wrapped when the expression is built is not
    /// detected as a lambda: construction only checks for a bare `LambdaExpr`.
    /// It is therefore evaluated as a value, and the inner `LambdaExpr`
    /// reports an error.
    #[test]
    fn test_higher_order_function_badly_wrapped_lambda() {
        let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
            signature: HigherOrderSignature::variadic_any(Volatility::Stable),
        }));

        let hof = HigherOrderFunctionExpr::try_new_with_schema(
            fun,
            vec![
                not(
                    lambda(["a"], Arc::new(Literal::new(ScalarValue::Int32(Some(42)))))
                        .unwrap(),
                )
                .unwrap(),
            ],
            &Schema::empty(),
            Arc::new(ConfigOptions::new()),
        )
        .unwrap();

        let result = hof
            .evaluate(
                &RecordBatch::try_new_with_options(
                    Arc::new(Schema::empty()),
                    vec![],
                    &RecordBatchOptions::new().with_row_count(Some(0)),
                )
                .unwrap(),
            )
            .unwrap_err();

        assert_contains!(
            result.to_string(),
            "LambdaExpr::evaluate() should not be called"
        );
    }

    /// `with_new_children` must reject a lambda at a position that
    /// previously held a value argument.
    #[test]
    fn test_higher_order_function_unexpected_lambda() {
        let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF {
            signature: HigherOrderSignature::variadic_any(Volatility::Stable),
        }));

        let hof = HigherOrderFunctionExpr::try_new_with_schema(
            fun,
            vec![Arc::new(NoOp::new())],
            &Schema::empty(),
            Arc::new(ConfigOptions::new()),
        )
        .unwrap();

        let result = Arc::new(hof)
            .with_new_children(vec![lambda(["a"], Arc::new(NoOp::new())).unwrap()])
            .unwrap_err();

        assert_contains!(
            result.to_string(),
            "mock_function received a lambda via with_new_children at position 0 that wasn't a lambda before"
        );
    }
}