datafusion_functions/core/
greatest_least_utils.rs1use arrow::array::{Array, ArrayRef, BooleanArray};
19use arrow::compute::kernels::zip::zip;
20use arrow::datatypes::DataType;
21use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
22use datafusion_expr_common::columnar_value::ColumnarValue;
23use datafusion_expr_common::type_coercion::binary::type_union_resolution;
24use std::sync::Arc;
25
26pub(super) trait GreatestLeastOperator {
27 const NAME: &'static str;
28
29 fn keep_scalar<'a>(
30 lhs: &'a ScalarValue,
31 rhs: &'a ScalarValue,
32 ) -> Result<&'a ScalarValue>;
33
34 fn get_indexes_to_keep(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray>;
36}
37
38fn keep_array<Op: GreatestLeastOperator>(
39 lhs: ArrayRef,
40 rhs: ArrayRef,
41) -> Result<ArrayRef> {
42 let keep_lhs = Op::get_indexes_to_keep(lhs.as_ref(), rhs.as_ref())?;
44
45 let result = zip(&keep_lhs, &lhs, &rhs)?;
46
47 Ok(result)
48}
49
50pub(super) fn execute_conditional<Op: GreatestLeastOperator>(
51 args: &[ColumnarValue],
52) -> Result<ColumnarValue> {
53 if args.is_empty() {
54 return internal_err!(
55 "{} was called with no arguments. It requires at least 1.",
56 Op::NAME
57 );
58 }
59
60 if args.len() == 1 {
62 return Ok(args[0].clone());
63 }
64
65 let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
67 ColumnarValue::Scalar(_) => true,
68 ColumnarValue::Array(_) => false,
69 });
70
71 let mut arrays_iter = arrays.iter().map(|x| match x {
72 ColumnarValue::Array(a) => a,
73 _ => unreachable!(),
74 });
75
76 let first_array = arrays_iter.next();
77
78 let mut result: ArrayRef;
79
80 if !scalars.is_empty() {
82 let mut scalars_iter = scalars.iter().map(|x| match x {
83 ColumnarValue::Scalar(s) => s,
84 _ => unreachable!(),
85 });
86
87 let mut result_scalar = scalars_iter.next().unwrap();
89
90 for scalar in scalars_iter {
91 result_scalar = Op::keep_scalar(result_scalar, scalar)?;
92 }
93
94 if arrays.is_empty() {
96 return Ok(ColumnarValue::Scalar(result_scalar.clone()));
97 }
98
99 let first_array = first_array.unwrap();
101
102 result = keep_array::<Op>(
104 Arc::clone(first_array),
105 result_scalar.to_array_of_size(first_array.len())?,
106 )?;
107 } else {
108 result = Arc::clone(first_array.unwrap());
111 }
112
113 for array in arrays_iter {
114 result = keep_array::<Op>(Arc::clone(array), result)?;
115 }
116
117 Ok(ColumnarValue::Array(result))
118}
119
120pub(super) fn find_coerced_type<Op: GreatestLeastOperator>(
121 data_types: &[DataType],
122) -> Result<DataType> {
123 if data_types.is_empty() {
124 plan_err!(
125 "{} was called without any arguments. It requires at least 1.",
126 Op::NAME
127 )
128 } else if let Some(coerced_type) = type_union_resolution(data_types) {
129 Ok(coerced_type)
130 } else {
131 plan_err!("Cannot find a common type for arguments")
132 }
133}