datafusion_physical_expr_common/
datum.rs1use arrow::array::BooleanArray;
19use arrow::array::{make_comparator, ArrayRef, Datum};
20use arrow::buffer::NullBuffer;
21use arrow::compute::SortOptions;
22use arrow::error::ArrowError;
23use datafusion_common::DataFusionError;
24use datafusion_common::{arrow_datafusion_err, internal_err};
25use datafusion_common::{Result, ScalarValue};
26use datafusion_expr_common::columnar_value::ColumnarValue;
27use datafusion_expr_common::operator::Operator;
28use std::sync::Arc;
29
30pub fn apply(
34 lhs: &ColumnarValue,
35 rhs: &ColumnarValue,
36 f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
37) -> Result<ColumnarValue> {
38 match (&lhs, &rhs) {
39 (ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
40 Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
41 }
42 (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
43 ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
44 ),
45 (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
46 ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
47 ),
48 (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
49 let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
50 let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
51 Ok(ColumnarValue::Scalar(scalar))
52 }
53 }
54}
55
56pub fn apply_cmp(
58 lhs: &ColumnarValue,
59 rhs: &ColumnarValue,
60 f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>,
61) -> Result<ColumnarValue> {
62 apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
63}
64
65pub fn apply_cmp_for_nested(
68 op: Operator,
69 lhs: &ColumnarValue,
70 rhs: &ColumnarValue,
71) -> Result<ColumnarValue> {
72 if matches!(
73 op,
74 Operator::Eq
75 | Operator::NotEq
76 | Operator::Lt
77 | Operator::Gt
78 | Operator::LtEq
79 | Operator::GtEq
80 | Operator::IsDistinctFrom
81 | Operator::IsNotDistinctFrom
82 ) {
83 apply(lhs, rhs, |l, r| {
84 Ok(Arc::new(compare_op_for_nested(op, l, r)?))
85 })
86 } else {
87 internal_err!("invalid operator for nested")
88 }
89}
90
91pub fn compare_with_eq(
93 lhs: &dyn Datum,
94 rhs: &dyn Datum,
95 is_nested: bool,
96) -> Result<BooleanArray> {
97 if is_nested {
98 compare_op_for_nested(Operator::Eq, lhs, rhs)
99 } else {
100 arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
101 }
102}
103
104pub fn compare_op_for_nested(
106 op: Operator,
107 lhs: &dyn Datum,
108 rhs: &dyn Datum,
109) -> Result<BooleanArray> {
110 let (l, is_l_scalar) = lhs.get();
111 let (r, is_r_scalar) = rhs.get();
112 let l_len = l.len();
113 let r_len = r.len();
114
115 if l_len != r_len && !is_l_scalar && !is_r_scalar {
116 return internal_err!("len mismatch");
117 }
118
119 let len = match is_l_scalar {
120 true => r_len,
121 false => l_len,
122 };
123
124 if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
126 && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
127 {
128 return Ok(BooleanArray::new_null(len));
129 }
130
131 let cmp = make_comparator(l, r, SortOptions::default())?;
134
135 let cmp_with_op = |i, j| match op {
136 Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
137 Operator::Lt => cmp(i, j).is_lt(),
138 Operator::Gt => cmp(i, j).is_gt(),
139 Operator::LtEq => !cmp(i, j).is_gt(),
140 Operator::GtEq => !cmp(i, j).is_lt(),
141 Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
142 _ => unreachable!("unexpected operator found"),
143 };
144
145 let values = match (is_l_scalar, is_r_scalar) {
146 (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
147 (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
148 (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
149 (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
150 };
151
152 if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
155 Ok(BooleanArray::new(values, None))
156 } else {
157 let nulls = NullBuffer::union(l.nulls(), r.nulls());
160 Ok(BooleanArray::new(values, nulls))
161 }
162}