datafusion_physical_expr_common/physical_expr.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt::{Debug, Display, Formatter};
20use std::hash::{Hash, Hasher};
21use std::sync::Arc;
22
23use crate::utils::scatter;
24
25use arrow::array::BooleanArray;
26use arrow::compute::filter_record_batch;
27use arrow::datatypes::{DataType, Schema};
28use arrow::record_batch::RecordBatch;
29use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
30use datafusion_expr_common::columnar_value::ColumnarValue;
31use datafusion_expr_common::interval_arithmetic::Interval;
32use datafusion_expr_common::sort_properties::ExprProperties;
33use datafusion_expr_common::statistics::Distribution;
34
35use itertools::izip;
36
37/// Shared [`PhysicalExpr`].
38pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
39
40/// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`.
41///
42/// `PhysicalExpr` knows its type, nullability and can be evaluated directly on
43/// a [`RecordBatch`] (see [`Self::evaluate`]).
44///
45/// `PhysicalExpr` are the physical counterpart to [`Expr`] used in logical
46/// planning. They are typically created from [`Expr`] by a [`PhysicalPlanner`]
47/// invoked from a higher level API
48///
49/// Some important examples of `PhysicalExpr` are:
50/// * [`Column`]: Represents a column at a given index in a RecordBatch
51///
52/// To create `PhysicalExpr` from `Expr`, see
53/// * [`SessionContext::create_physical_expr`]: A high level API
54/// * [`create_physical_expr`]: A low level API
55///
56/// [`SessionContext::create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.create_physical_expr
57/// [`PhysicalPlanner`]: https://docs.rs/datafusion/latest/datafusion/physical_planner/trait.PhysicalPlanner.html
58/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
59/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
60/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
61pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash {
62 /// Returns the physical expression as [`Any`] so that it can be
63 /// downcast to a specific implementation.
64 fn as_any(&self) -> &dyn Any;
65 /// Get the data type of this expression, given the schema of the input
66 fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
67 /// Determine whether this expression is nullable, given the schema of the input
68 fn nullable(&self, input_schema: &Schema) -> Result<bool>;
69 /// Evaluate an expression against a RecordBatch
70 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
71 /// Evaluate an expression against a RecordBatch after first applying a
72 /// validity array
73 fn evaluate_selection(
74 &self,
75 batch: &RecordBatch,
76 selection: &BooleanArray,
77 ) -> Result<ColumnarValue> {
78 let tmp_batch = filter_record_batch(batch, selection)?;
79
80 let tmp_result = self.evaluate(&tmp_batch)?;
81
82 if batch.num_rows() == tmp_batch.num_rows() {
83 // All values from the `selection` filter are true.
84 Ok(tmp_result)
85 } else if let ColumnarValue::Array(a) = tmp_result {
86 scatter(selection, a.as_ref()).map(ColumnarValue::Array)
87 } else {
88 Ok(tmp_result)
89 }
90 }
91
92 /// Get a list of child PhysicalExpr that provide the input for this expr.
93 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;
94
95 /// Returns a new PhysicalExpr where all children were replaced by new exprs.
96 fn with_new_children(
97 self: Arc<Self>,
98 children: Vec<Arc<dyn PhysicalExpr>>,
99 ) -> Result<Arc<dyn PhysicalExpr>>;
100
101 /// Computes the output interval for the expression, given the input
102 /// intervals.
103 ///
104 /// # Parameters
105 ///
106 /// * `children` are the intervals for the children (inputs) of this
107 /// expression.
108 ///
109 /// # Returns
110 ///
111 /// A `Result` containing the output interval for the expression in
112 /// case of success, or an error object in case of failure.
113 ///
114 /// # Example
115 ///
116 /// If the expression is `a + b`, and the input intervals are `a: [1, 2]`
117 /// and `b: [3, 4]`, then the output interval would be `[4, 6]`.
118 fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
119 not_impl_err!("Not implemented for {self}")
120 }
121
122 /// Updates bounds for child expressions, given a known interval for this
123 /// expression.
124 ///
125 /// This is used to propagate constraints down through an expression tree.
126 ///
127 /// # Parameters
128 ///
129 /// * `interval` is the currently known interval for this expression.
130 /// * `children` are the current intervals for the children of this expression.
131 ///
132 /// # Returns
133 ///
134 /// A `Result` containing a `Vec` of new intervals for the children (in order)
135 /// in case of success, or an error object in case of failure.
136 ///
137 /// If constraint propagation reveals an infeasibility for any child, returns
138 /// [`None`]. If none of the children intervals change as a result of
139 /// propagation, may return an empty vector instead of cloning `children`.
140 /// This is the default (and conservative) return value.
141 ///
142 /// # Example
143 ///
144 /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the
145 /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then
146 /// propagation would return `[0, 2]` and `[2, 4]` as `b` must be at least
147 /// `2` to make the output at least `4`.
148 fn propagate_constraints(
149 &self,
150 _interval: &Interval,
151 _children: &[&Interval],
152 ) -> Result<Option<Vec<Interval>>> {
153 Ok(Some(vec![]))
154 }
155
156 /// Computes the output statistics for the expression, given the input
157 /// statistics.
158 ///
159 /// # Parameters
160 ///
161 /// * `children` are the statistics for the children (inputs) of this
162 /// expression.
163 ///
164 /// # Returns
165 ///
166 /// A `Result` containing the output statistics for the expression in
167 /// case of success, or an error object in case of failure.
168 ///
169 /// Expressions (should) implement this function and utilize the independence
170 /// assumption, match on children distribution types and compute the output
171 /// statistics accordingly. The default implementation simply creates an
172 /// unknown output distribution by combining input ranges. This logic loses
173 /// distribution information, but is a safe default.
174 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
175 let children_ranges = children
176 .iter()
177 .map(|c| c.range())
178 .collect::<Result<Vec<_>>>()?;
179 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
180 let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
181 let dt = output_interval.data_type();
182 if dt.eq(&DataType::Boolean) {
183 let p = if output_interval.eq(&Interval::CERTAINLY_TRUE) {
184 ScalarValue::new_one(&dt)
185 } else if output_interval.eq(&Interval::CERTAINLY_FALSE) {
186 ScalarValue::new_zero(&dt)
187 } else {
188 ScalarValue::try_from(&dt)
189 }?;
190 Distribution::new_bernoulli(p)
191 } else {
192 Distribution::new_from_interval(output_interval)
193 }
194 }
195
196 /// Updates children statistics using the given parent statistic for this
197 /// expression.
198 ///
199 /// This is used to propagate statistics down through an expression tree.
200 ///
201 /// # Parameters
202 ///
203 /// * `parent` is the currently known statistics for this expression.
204 /// * `children` are the current statistics for the children of this expression.
205 ///
206 /// # Returns
207 ///
208 /// A `Result` containing a `Vec` of new statistics for the children (in order)
209 /// in case of success, or an error object in case of failure.
210 ///
211 /// If statistics propagation reveals an infeasibility for any child, returns
212 /// [`None`]. If none of the children statistics change as a result of
213 /// propagation, may return an empty vector instead of cloning `children`.
214 /// This is the default (and conservative) return value.
215 ///
216 /// Expressions (should) implement this function and apply Bayes rule to
217 /// reconcile and update parent/children statistics. This involves utilizing
218 /// the independence assumption, and matching on distribution types. The
219 /// default implementation simply creates an unknown distribution if it can
220 /// narrow the range by propagating ranges. This logic loses distribution
221 /// information, but is a safe default.
222 fn propagate_statistics(
223 &self,
224 parent: &Distribution,
225 children: &[&Distribution],
226 ) -> Result<Option<Vec<Distribution>>> {
227 let children_ranges = children
228 .iter()
229 .map(|c| c.range())
230 .collect::<Result<Vec<_>>>()?;
231 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
232 let parent_range = parent.range()?;
233 let Some(propagated_children) =
234 self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
235 else {
236 return Ok(None);
237 };
238 izip!(propagated_children.into_iter(), children_ranges, children)
239 .map(|(new_interval, old_interval, child)| {
240 if new_interval == old_interval {
241 // We weren't able to narrow the range, preserve the old statistics.
242 Ok((*child).clone())
243 } else if new_interval.data_type().eq(&DataType::Boolean) {
244 let dt = old_interval.data_type();
245 let p = if new_interval.eq(&Interval::CERTAINLY_TRUE) {
246 ScalarValue::new_one(&dt)
247 } else if new_interval.eq(&Interval::CERTAINLY_FALSE) {
248 ScalarValue::new_zero(&dt)
249 } else {
250 unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
251 }?;
252 Distribution::new_bernoulli(p)
253 } else {
254 Distribution::new_from_interval(new_interval)
255 }
256 })
257 .collect::<Result<_>>()
258 .map(Some)
259 }
260
261 /// Calculates the properties of this [`PhysicalExpr`] based on its
262 /// children's properties (i.e. order and range), recursively aggregating
263 /// the information from its children. In cases where the [`PhysicalExpr`]
264 /// has no children (e.g., `Literal` or `Column`), these properties should
265 /// be specified externally, as the function defaults to unknown properties.
266 fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
267 Ok(ExprProperties::new_unknown())
268 }
269}
270
/// Object-safe equality for [`PhysicalExpr`].
///
/// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must
/// remain object safe. To ease implementation, a blanket implementation is
/// provided for every [`Eq`] type.
pub trait DynEq {
    /// Returns `true` when `other` is equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    /// Equal only when `other` downcasts to `Self` and the downcast value
    /// compares equal to `self`.
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        other
            .downcast_ref::<Self>()
            .is_some_and(|same_type| same_type == self)
    }
}
282
283impl PartialEq for dyn PhysicalExpr {
284 fn eq(&self, other: &Self) -> bool {
285 self.dyn_eq(other.as_any())
286 }
287}
288
289impl Eq for dyn PhysicalExpr {}
290
/// Object-safe hashing for [`PhysicalExpr`].
///
/// [`PhysicalExpr`] can't be constrained by [`Hash`] directly because it must
/// remain object safe. To ease implementation, a blanket implementation is
/// provided for every [`Hash`] type.
pub trait DynHash {
    /// Feeds `self` into the given type-erased hasher.
    fn dyn_hash(&self, _state: &mut dyn Hasher);
}

impl<T: Hash + Any> DynHash for T {
    /// Hashes the type id of `T` first and the value itself second.
    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
        // `&mut state` is a `&mut &mut dyn Hasher`; the inner `&mut dyn Hasher`
        // is a sized `Hasher`, which is what `Hash::hash` needs.
        Hash::hash(&Any::type_id(self), &mut state);
        Hash::hash(self, &mut state);
    }
}
304
305impl Hash for dyn PhysicalExpr {
306 fn hash<H: Hasher>(&self, state: &mut H) {
307 self.dyn_hash(state);
308 }
309}
310
311/// Returns a copy of this expr if we change any child according to the pointer comparison.
312/// The size of `children` must be equal to the size of `PhysicalExpr::children()`.
313pub fn with_new_children_if_necessary(
314 expr: Arc<dyn PhysicalExpr>,
315 children: Vec<Arc<dyn PhysicalExpr>>,
316) -> Result<Arc<dyn PhysicalExpr>> {
317 let old_children = expr.children();
318 if children.len() != old_children.len() {
319 internal_err!("PhysicalExpr: Wrong number of children")
320 } else if children.is_empty()
321 || children
322 .iter()
323 .zip(old_children.iter())
324 .any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
325 {
326 Ok(expr.with_new_children(children)?)
327 } else {
328 Ok(expr)
329 }
330}
331
332#[deprecated(since = "44.0.0")]
333pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
334 if any.is::<Arc<dyn PhysicalExpr>>() {
335 any.downcast_ref::<Arc<dyn PhysicalExpr>>()
336 .unwrap()
337 .as_any()
338 } else if any.is::<Box<dyn PhysicalExpr>>() {
339 any.downcast_ref::<Box<dyn PhysicalExpr>>()
340 .unwrap()
341 .as_any()
342 } else {
343 any
344 }
345}
346
/// Wraps a list of [`PhysicalExpr`]s (or any other [`Display`] items) in a
/// value that implements [`Display`].
///
/// The items are written comma separated inside square brackets.
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    /// Holds the iterator; it is cloned on every `fmt` call so the wrapper
    /// can be formatted through `&self`.
    struct BracketedList<I>(I)
    where
        I: Iterator + Clone,
        I::Item: Display;

    impl<I> Display for BracketedList<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            let mut items = self.0.clone();
            f.write_str("[")?;
            // The first item carries no separator; every later one is
            // prefixed with ", ".
            if let Some(head) = items.next() {
                write!(f, "{head}")?;
            }
            items.try_for_each(|item| write!(f, ", {item}"))?;
            f.write_str("]")
        }
    }

    BracketedList(exprs.into_iter())
}
381}