datafusion_physical_expr_common/
datum.rs1use arrow::array::BooleanArray;
19use arrow::array::{ArrayRef, Datum, make_comparator};
20use arrow::buffer::NullBuffer;
21use arrow::compute::kernels::cmp::{
22 distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct,
23};
24use arrow::compute::{SortOptions, ilike, like, nilike, nlike};
25use arrow::error::ArrowError;
26use datafusion_common::{Result, ScalarValue};
27use datafusion_common::{arrow_datafusion_err, assert_or_internal_err, internal_err};
28use datafusion_expr_common::columnar_value::ColumnarValue;
29use datafusion_expr_common::operator::Operator;
30use std::sync::Arc;
31
32pub fn apply(
36 lhs: &ColumnarValue,
37 rhs: &ColumnarValue,
38 f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
39) -> Result<ColumnarValue> {
40 match (&lhs, &rhs) {
41 (ColumnarValue::Array(left), ColumnarValue::Array(right)) => {
42 Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?))
43 }
44 (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok(
45 ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?),
46 ),
47 (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok(
48 ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?),
49 ),
50 (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => {
51 let array = f(&left.to_scalar()?, &right.to_scalar()?)?;
52 let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?;
53 Ok(ColumnarValue::Scalar(scalar))
54 }
55 }
56}
57
58pub fn apply_cmp(
60 op: Operator,
61 lhs: &ColumnarValue,
62 rhs: &ColumnarValue,
63) -> Result<ColumnarValue> {
64 if lhs.data_type().is_nested() {
65 apply_cmp_for_nested(op, lhs, rhs)
66 } else {
67 let f = match op {
68 Operator::Eq => eq,
69 Operator::NotEq => neq,
70 Operator::Lt => lt,
71 Operator::LtEq => lt_eq,
72 Operator::Gt => gt,
73 Operator::GtEq => gt_eq,
74 Operator::IsDistinctFrom => distinct,
75 Operator::IsNotDistinctFrom => not_distinct,
76
77 Operator::LikeMatch => like,
78 Operator::ILikeMatch => ilike,
79 Operator::NotLikeMatch => nlike,
80 Operator::NotILikeMatch => nilike,
81
82 _ => {
83 return internal_err!("Invalid compare operator: {}", op);
84 }
85 };
86
87 apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?)))
88 }
89}
90
91pub fn apply_cmp_for_nested(
94 op: Operator,
95 lhs: &ColumnarValue,
96 rhs: &ColumnarValue,
97) -> Result<ColumnarValue> {
98 let left_data_type = lhs.data_type();
99 let right_data_type = rhs.data_type();
100
101 assert_or_internal_err!(
102 matches!(
103 op,
104 Operator::Eq
105 | Operator::NotEq
106 | Operator::Lt
107 | Operator::Gt
108 | Operator::LtEq
109 | Operator::GtEq
110 | Operator::IsDistinctFrom
111 | Operator::IsNotDistinctFrom
112 ) && left_data_type.equals_datatype(&right_data_type),
113 "invalid operator or data type mismatch for nested data, op {op} left {left_data_type}, right {right_data_type}",
114 );
115
116 apply(lhs, rhs, |l, r| {
117 Ok(Arc::new(compare_op_for_nested(op, l, r)?))
118 })
119}
120
121pub fn compare_with_eq(
123 lhs: &dyn Datum,
124 rhs: &dyn Datum,
125 is_nested: bool,
126) -> Result<BooleanArray> {
127 if is_nested {
128 compare_op_for_nested(Operator::Eq, lhs, rhs)
129 } else {
130 eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e))
131 }
132}
133
134pub fn compare_op_for_nested(
136 op: Operator,
137 lhs: &dyn Datum,
138 rhs: &dyn Datum,
139) -> Result<BooleanArray> {
140 let (l, is_l_scalar) = lhs.get();
141 let (r, is_r_scalar) = rhs.get();
142 let l_len = l.len();
143 let r_len = r.len();
144
145 assert_or_internal_err!(l_len == r_len || is_l_scalar || is_r_scalar, "len mismatch");
146
147 let len = match is_l_scalar {
148 true => r_len,
149 false => l_len,
150 };
151
152 if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom)
154 && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1)
155 {
156 return Ok(BooleanArray::new_null(len));
157 }
158
159 let cmp = make_comparator(l, r, SortOptions::default())?;
162
163 let cmp_with_op = |i, j| match op {
164 Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(),
165 Operator::Lt => cmp(i, j).is_lt(),
166 Operator::Gt => cmp(i, j).is_gt(),
167 Operator::LtEq => !cmp(i, j).is_gt(),
168 Operator::GtEq => !cmp(i, j).is_lt(),
169 Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(),
170 _ => unreachable!("unexpected operator found"),
171 };
172
173 let values = match (is_l_scalar, is_r_scalar) {
174 (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(),
175 (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(),
176 (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(),
177 (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(),
178 };
179
180 if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) {
183 Ok(BooleanArray::new(values, None))
184 } else {
185 let nulls = match (is_l_scalar, is_r_scalar) {
189 (false, false) | (true, true) => NullBuffer::union(l.nulls(), r.nulls()),
190 (true, false) => {
191 match l.nulls().filter(|nulls| !nulls.is_valid(0)) {
193 Some(_) => Some(NullBuffer::new_null(len)), None => r.nulls().cloned(), }
196 }
197 (false, true) => {
198 match r.nulls().filter(|nulls| !nulls.is_valid(0)) {
200 Some(_) => Some(NullBuffer::new_null(len)), None => l.nulls().cloned(), }
203 }
204 };
205 Ok(BooleanArray::new(values, nulls))
206 }
207}