datafusion_physical_expr_common/physical_expr.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::fmt::{Debug, Display, Formatter};
21use std::hash::{Hash, Hasher};
22use std::sync::Arc;
23
24use crate::utils::scatter;
25
26use arrow::array::{ArrayRef, BooleanArray, new_empty_array};
27use arrow::compute::filter_record_batch;
28use arrow::datatypes::{DataType, Field, FieldRef, Schema};
29use arrow::record_batch::RecordBatch;
30use datafusion_common::tree_node::{
31 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{
34 Result, ScalarValue, assert_eq_or_internal_err, exec_err, not_impl_err,
35};
36use datafusion_expr_common::columnar_value::ColumnarValue;
37use datafusion_expr_common::interval_arithmetic::Interval;
38use datafusion_expr_common::placement::ExpressionPlacement;
39use datafusion_expr_common::sort_properties::ExprProperties;
40use datafusion_expr_common::statistics::Distribution;
41
42use itertools::izip;
43
/// Shared, reference-counted handle to a [`PhysicalExpr`] trait object.
///
/// This is a plain alias for `Arc<dyn PhysicalExpr>`; the two spellings are
/// interchangeable.
pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
46
47/// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`.
48///
49/// `PhysicalExpr` knows its type, nullability and can be evaluated directly on
50/// a [`RecordBatch`] (see [`Self::evaluate`]).
51///
52/// `PhysicalExpr` are the physical counterpart to [`Expr`] used in logical
53/// planning. They are typically created from [`Expr`] by a [`PhysicalPlanner`]
54/// invoked from a higher level API
55///
56/// Some important examples of `PhysicalExpr` are:
57/// * [`Column`]: Represents a column at a given index in a RecordBatch
58///
59/// To create `PhysicalExpr` from `Expr`, see
60/// * [`SessionContext::create_physical_expr`]: A high level API
61/// * [`create_physical_expr`]: A low level API
62///
63/// # Formatting `PhysicalExpr` as strings
64/// There are three ways to format `PhysicalExpr` as a string:
65/// * [`Debug`]: Standard Rust debugging format (e.g. `Constant { value: ... }`)
/// * [`Display`]: Detailed SQL-like format that shows expression structure (e.g. `Utf8 ("foobar")`). This is often used for debugging and tests
/// * [`Self::fmt_sql`]: SQL-like human readable format (e.g. `'foobar'`). See also [`fmt_sql`]
68///
69/// [`SessionContext::create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.create_physical_expr
70/// [`PhysicalPlanner`]: https://docs.rs/datafusion/latest/datafusion/physical_planner/trait.PhysicalPlanner.html
71/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html
72/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
73/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
74pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
    /// Returns the physical expression as [`Any`] so that it can be
    /// downcast to a specific implementation.
    ///
    /// This is a required method; implementations normally just return `self`.
    fn as_any(&self) -> &dyn Any;
78 /// Get the data type of this expression, given the schema of the input
79 fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
80 Ok(self.return_field(input_schema)?.data_type().to_owned())
81 }
82 /// Determine whether this expression is nullable, given the schema of the input
83 fn nullable(&self, input_schema: &Schema) -> Result<bool> {
84 Ok(self.return_field(input_schema)?.is_nullable())
85 }
    /// Evaluate an expression against a RecordBatch
    ///
    /// The result is a [`ColumnarValue`], i.e. either an array or a single
    /// scalar value. This is a required method.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue>;
88 /// The output field associated with this expression
89 fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> {
90 Ok(Arc::new(Field::new(
91 format!("{self}"),
92 self.data_type(input_schema)?,
93 self.nullable(input_schema)?,
94 )))
95 }
96 /// Evaluate an expression against a RecordBatch after first applying a validity array
97 ///
98 /// # Errors
99 ///
100 /// Returns an `Err` if the expression could not be evaluated or if the length of the
101 /// `selection` validity array and the number of row in `batch` is not equal.
102 fn evaluate_selection(
103 &self,
104 batch: &RecordBatch,
105 selection: &BooleanArray,
106 ) -> Result<ColumnarValue> {
107 let row_count = batch.num_rows();
108 if row_count != selection.len() {
109 return exec_err!(
110 "Selection array length does not match batch row count: {} != {row_count}",
111 selection.len()
112 );
113 }
114
115 let selection_count = selection.true_count();
116
117 // First, check if we can avoid filtering altogether.
118 if selection_count == row_count {
119 // All values from the `selection` filter are true and match the input batch.
120 // No need to perform any filtering.
121 return self.evaluate(batch);
122 }
123
124 // Next, prepare the result array for each 'true' row in the selection vector.
125 let filtered_result = if selection_count == 0 {
126 // Do not call `evaluate` when the selection is empty.
127 // `evaluate_selection` is used to conditionally evaluate expressions.
128 // When the expression in question is fallible, evaluating it with an empty
129 // record batch may trigger a runtime error (e.g. division by zero).
130 //
131 // Instead, create an empty array matching the expected return type.
132 let datatype = self.data_type(batch.schema_ref().as_ref())?;
133 ColumnarValue::Array(new_empty_array(&datatype))
134 } else {
135 // If we reach this point, there's no other option than to filter the batch.
136 // This is a fairly costly operation since it requires creating partial copies
137 // (worst case of length `row_count - 1`) of all the arrays in the record batch.
138 // The resulting `filtered_batch` will contain `selection_count` rows.
139 let filtered_batch = filter_record_batch(batch, selection)?;
140 self.evaluate(&filtered_batch)?
141 };
142
143 // Finally, scatter the filtered result array so that the indices match the input rows again.
144 match &filtered_result {
145 ColumnarValue::Array(a) => {
146 scatter(selection, a.as_ref()).map(ColumnarValue::Array)
147 }
148 ColumnarValue::Scalar(ScalarValue::Boolean(value)) => {
149 // When the scalar is true or false, skip the scatter process
150 if let Some(v) = value {
151 if *v {
152 Ok(ColumnarValue::from(Arc::new(selection.clone()) as ArrayRef))
153 } else {
154 Ok(filtered_result)
155 }
156 } else {
157 let array = BooleanArray::from(vec![None; row_count]);
158 scatter(selection, &array).map(ColumnarValue::Array)
159 }
160 }
161 ColumnarValue::Scalar(_) => Ok(filtered_result),
162 }
163 }
164
    /// Get a list of child PhysicalExpr that provide the input for this expr.
    ///
    /// Expressions without inputs return an empty vector.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>;
167
    /// Returns a new PhysicalExpr where all children were replaced by new exprs.
    ///
    /// `children` is given in the same order as [`Self::children`] returns them.
    /// [`with_new_children_if_necessary`] checks that the two lengths agree
    /// before calling this method.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>>;
173
    /// Computes the output interval for the expression, given the input
    /// intervals.
    ///
    /// # Parameters
    ///
    /// * `children` are the intervals for the children (inputs) of this
    ///   expression.
    ///
    /// # Returns
    ///
    /// A `Result` containing the output interval for the expression in
    /// case of success, or an error object in case of failure.
    ///
    /// The default implementation does no interval arithmetic: it always
    /// returns the error built by `not_impl_err!`, whose message includes this
    /// expression's `Display` form.
    ///
    /// # Example
    ///
    /// If the expression is `a + b`, and the input intervals are `a: [1, 2]`
    /// and `b: [3, 4]`, then the output interval would be `[4, 6]`.
    fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
        not_impl_err!("Not implemented for {self}")
    }
194
195 /// Updates bounds for child expressions, given a known interval for this
196 /// expression.
197 ///
198 /// This is used to propagate constraints down through an expression tree.
199 ///
200 /// # Parameters
201 ///
202 /// * `interval` is the currently known interval for this expression.
203 /// * `children` are the current intervals for the children of this expression.
204 ///
205 /// # Returns
206 ///
207 /// A `Result` containing a `Vec` of new intervals for the children (in order)
208 /// in case of success, or an error object in case of failure.
209 ///
210 /// If constraint propagation reveals an infeasibility for any child, returns
211 /// [`None`]. If none of the children intervals change as a result of
212 /// propagation, may return an empty vector instead of cloning `children`.
213 /// This is the default (and conservative) return value.
214 ///
215 /// # Example
216 ///
217 /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the
218 /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then
219 /// propagation would return `[0, 2]` and `[2, 4]` as `b` must be at least
220 /// `2` to make the output at least `4`.
221 fn propagate_constraints(
222 &self,
223 _interval: &Interval,
224 _children: &[&Interval],
225 ) -> Result<Option<Vec<Interval>>> {
226 Ok(Some(vec![]))
227 }
228
229 /// Computes the output statistics for the expression, given the input
230 /// statistics.
231 ///
232 /// # Parameters
233 ///
234 /// * `children` are the statistics for the children (inputs) of this
235 /// expression.
236 ///
237 /// # Returns
238 ///
239 /// A `Result` containing the output statistics for the expression in
240 /// case of success, or an error object in case of failure.
241 ///
242 /// Expressions (should) implement this function and utilize the independence
243 /// assumption, match on children distribution types and compute the output
244 /// statistics accordingly. The default implementation simply creates an
245 /// unknown output distribution by combining input ranges. This logic loses
246 /// distribution information, but is a safe default.
247 fn evaluate_statistics(&self, children: &[&Distribution]) -> Result<Distribution> {
248 let children_ranges = children
249 .iter()
250 .map(|c| c.range())
251 .collect::<Result<Vec<_>>>()?;
252 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
253 let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?;
254 let dt = output_interval.data_type();
255 if dt.eq(&DataType::Boolean) {
256 let p = if output_interval.eq(&Interval::TRUE) {
257 ScalarValue::new_one(&dt)
258 } else if output_interval.eq(&Interval::FALSE) {
259 ScalarValue::new_zero(&dt)
260 } else {
261 ScalarValue::try_from(&dt)
262 }?;
263 Distribution::new_bernoulli(p)
264 } else {
265 Distribution::new_from_interval(output_interval)
266 }
267 }
268
269 /// Updates children statistics using the given parent statistic for this
270 /// expression.
271 ///
272 /// This is used to propagate statistics down through an expression tree.
273 ///
274 /// # Parameters
275 ///
276 /// * `parent` is the currently known statistics for this expression.
277 /// * `children` are the current statistics for the children of this expression.
278 ///
279 /// # Returns
280 ///
281 /// A `Result` containing a `Vec` of new statistics for the children (in order)
282 /// in case of success, or an error object in case of failure.
283 ///
284 /// If statistics propagation reveals an infeasibility for any child, returns
285 /// [`None`]. If none of the children statistics change as a result of
286 /// propagation, may return an empty vector instead of cloning `children`.
287 /// This is the default (and conservative) return value.
288 ///
289 /// Expressions (should) implement this function and apply Bayes rule to
290 /// reconcile and update parent/children statistics. This involves utilizing
291 /// the independence assumption, and matching on distribution types. The
292 /// default implementation simply creates an unknown distribution if it can
293 /// narrow the range by propagating ranges. This logic loses distribution
294 /// information, but is a safe default.
295 fn propagate_statistics(
296 &self,
297 parent: &Distribution,
298 children: &[&Distribution],
299 ) -> Result<Option<Vec<Distribution>>> {
300 let children_ranges = children
301 .iter()
302 .map(|c| c.range())
303 .collect::<Result<Vec<_>>>()?;
304 let children_ranges_refs = children_ranges.iter().collect::<Vec<_>>();
305 let parent_range = parent.range()?;
306 let Some(propagated_children) =
307 self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())?
308 else {
309 return Ok(None);
310 };
311 izip!(propagated_children.into_iter(), children_ranges, children)
312 .map(|(new_interval, old_interval, child)| {
313 if new_interval == old_interval {
314 // We weren't able to narrow the range, preserve the old statistics.
315 Ok((*child).clone())
316 } else if new_interval.data_type().eq(&DataType::Boolean) {
317 let dt = old_interval.data_type();
318 let p = if new_interval.eq(&Interval::TRUE) {
319 ScalarValue::new_one(&dt)
320 } else if new_interval.eq(&Interval::FALSE) {
321 ScalarValue::new_zero(&dt)
322 } else {
323 unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty")
324 }?;
325 Distribution::new_bernoulli(p)
326 } else {
327 Distribution::new_from_interval(new_interval)
328 }
329 })
330 .collect::<Result<_>>()
331 .map(Some)
332 }
333
    /// Calculates the properties of this [`PhysicalExpr`] based on its
    /// children's properties (i.e. order and range), recursively aggregating
    /// the information from its children. In cases where the [`PhysicalExpr`]
    /// has no children (e.g., `Literal` or `Column`), these properties should
    /// be specified externally, as the function defaults to unknown properties.
    ///
    /// The default implementation ignores `_children` and returns
    /// `ExprProperties::new_unknown()`.
    fn get_properties(&self, _children: &[ExprProperties]) -> Result<ExprProperties> {
        Ok(ExprProperties::new_unknown())
    }
342
    /// Format this `PhysicalExpr` in nice human readable "SQL" format
    ///
    /// Specifically, this format is designed to be readable by humans, at the
    /// expense of details. Use `Display` or `Debug` for more detailed
    /// representation.
    ///
    /// This is a required method. It writes into the supplied [`Formatter`]
    /// rather than returning a string; the free function [`fmt_sql`] wraps an
    /// expression so that this output can be used with `format!` and friends.
    ///
    /// See the [`fmt_sql`] function for an example of printing `PhysicalExpr`s as SQL.
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result;
351
    /// Take a snapshot of this `PhysicalExpr`, if it is dynamic.
    ///
    /// "Dynamic" in this case means containing references to structures that may change
    /// during plan execution, such as hash tables.
    ///
    /// This method is used to capture the current state of `PhysicalExpr`s that may contain
    /// dynamic references to other operators in order to serialize it over the wire
    /// or treat it via downcast matching.
    ///
    /// You should not call this method directly as it does not handle recursion.
    /// Instead use [`snapshot_physical_expr`] to handle recursion and capture the
    /// full state of the `PhysicalExpr`.
    ///
    /// This is expected to return "simple" expressions that do not have mutable state
    /// and are composed of DataFusion's built-in `PhysicalExpr` implementations.
    /// Callers however should *not* assume anything about the returned expressions
    /// since callers and implementers may not agree on what "simple" or "built-in"
    /// means.
    /// In other words, if you need to serialize a `PhysicalExpr` across the wire
    /// you should call this method and then try to serialize the result,
    /// but you should handle unknown or unexpected `PhysicalExpr` implementations gracefully
    /// just as if you had not called this method at all.
    ///
    /// In particular, consider:
    /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::TopK`
    ///   that is involved in a query with `SELECT * FROM t1 ORDER BY a LIMIT 10`.
    ///   This function may return something like `a >= 12`.
    /// * A `PhysicalExpr` that references the current state of a `datafusion::physical_plan::joins::HashJoinExec`
    ///   from a query such as `SELECT * FROM t1 JOIN t2 ON t1.a = t2.b`.
    ///   This function may return something like `t2.b IN (1, 5, 7)`.
    ///
    /// A system or function that can only deal with a hardcoded set of `PhysicalExpr` implementations
    /// or needs to serialize this state to bytes may not be able to handle these dynamic references.
    /// In such cases, we should return a simplified version of the `PhysicalExpr` that does not
    /// contain these dynamic references.
    ///
    /// Systems that implement remote execution of plans, e.g. serialize a portion of the query plan
    /// and send it across the wire to a remote executor may want to call this method after
    /// every batch on the source side and broadcast / update the current snapshot to the remote executor.
    ///
    /// # Returns
    ///
    /// `Ok(Some(expr))` with the snapshot when this node is dynamic, or
    /// `Ok(None)` when there is nothing to snapshot. The default
    /// implementation always returns `Ok(None)`.
    ///
    /// Note for implementers: this method should *not* handle recursion.
    /// Recursion is handled in [`snapshot_physical_expr`].
    fn snapshot(&self) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        // By default, we return None to indicate that this PhysicalExpr does not
        // have any dynamic references or state.
        // This is a safe default behavior.
        Ok(None)
    }
400
    /// Returns the generation of this `PhysicalExpr` for snapshotting purposes.
    /// The generation is an arbitrary u64 that can be used to track changes
    /// in the state of the `PhysicalExpr` over time without having to do an exhaustive comparison.
    /// This is useful to avoid unnecessary computation or serialization if there are no changes to the expression.
    /// In particular, for dynamic expressions that may change over time, this allows cheap checks for changes.
    /// Static expressions that do not change over time should return 0, as does the default implementation.
    /// You should not call this method directly as it does not handle recursion.
    /// Instead use the free function [`snapshot_generation`] to handle recursion and capture the
    /// full state of the `PhysicalExpr`.
    fn snapshot_generation(&self) -> u64 {
        // By default, we return 0 to indicate that this PhysicalExpr does not
        // have any dynamic references or state.
        // Since the recursive algorithm sums (with wrapping addition) the
        // generations of all nodes in the tree, the overall generation will be 0
        // if no node has a non-zero generation, meaning that static expressions
        // will always return 0.
        0
    }
418
    /// Returns true if the expression node is volatile, i.e. whether it can return
    /// different results when evaluated multiple times with the same input.
    ///
    /// Note: unlike the free function [`is_volatile`], this method does not consider inputs:
    /// - `random()` returns `true`,
    /// - `a + random()` returns `false` (because the operation `+` itself is not volatile.)
    ///
    /// The default of this function was set to `false` when it was created
    /// to avoid imposing API churn on implementers, but this is not a safe default in general.
    /// It is highly recommended that volatile expressions implement this method and return `true`.
    /// This default may be removed in the future if it causes problems or we decide to
    /// eat the cost of the breaking change and require all implementers to make a choice.
    fn is_volatile_node(&self) -> bool {
        // Default: treat the node as non-volatile (see the caveat above).
        false
    }
434
    /// Returns placement information for this expression.
    ///
    /// This is used by optimizers to make decisions about expression placement,
    /// such as whether to push expressions down through projections.
    ///
    /// The default implementation returns [`ExpressionPlacement::KeepInPlace`];
    /// override it to report a different placement.
    fn placement(&self) -> ExpressionPlacement {
        ExpressionPlacement::KeepInPlace
    }
444}
445
446#[deprecated(
447 since = "50.0.0",
448 note = "Use `datafusion_expr_common::dyn_eq` instead"
449)]
450pub use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
451
// Equality for type-erased expressions: the comparison is delegated to the
// `dyn_eq` method (from the `DynEq` supertrait), which receives the other
// expression as `&dyn Any`.
impl PartialEq for dyn PhysicalExpr {
    fn eq(&self, other: &Self) -> bool {
        self.dyn_eq(other.as_any())
    }
}
// Marker impl declaring the equality above to be a full equivalence relation.
impl Eq for dyn PhysicalExpr {}
458
// Hashing for type-erased expressions: delegated to the `dyn_hash` method
// (from the `DynHash` supertrait), which feeds the concrete expression into
// the supplied hasher.
impl Hash for dyn PhysicalExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.dyn_hash(state);
    }
}
464
465/// Returns a copy of this expr if we change any child according to the pointer comparison.
466/// The size of `children` must be equal to the size of `PhysicalExpr::children()`.
467pub fn with_new_children_if_necessary(
468 expr: Arc<dyn PhysicalExpr>,
469 children: Vec<Arc<dyn PhysicalExpr>>,
470) -> Result<Arc<dyn PhysicalExpr>> {
471 let old_children = expr.children();
472 assert_eq_or_internal_err!(
473 children.len(),
474 old_children.len(),
475 "PhysicalExpr: Wrong number of children"
476 );
477
478 if children.is_empty()
479 || children
480 .iter()
481 .zip(old_children.iter())
482 .any(|(c1, c2)| !Arc::ptr_eq(c1, c2))
483 {
484 Ok(expr.with_new_children(children)?)
485 } else {
486 Ok(expr)
487 }
488}
489
/// Wraps a list of displayable items (typically [`PhysicalExpr`]s) in a value
/// that implements [`Display`].
///
/// The items are rendered with their own `Display` implementation, separated
/// by `", "` and enclosed in square brackets. Formatting is lazy: the iterator
/// is cloned and walked each time the returned value is formatted.
///
/// Example output: `[a + 1, b]`
pub fn format_physical_expr_list<T>(exprs: T) -> impl Display
where
    T: IntoIterator,
    T::Item: Display,
    T::IntoIter: Clone,
{
    /// Holds the (cloneable) iterator until it is formatted.
    struct BracketedList<I>(I);

    impl<I> Display for BracketedList<I>
    where
        I: Iterator + Clone,
        I::Item: Display,
    {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            write!(f, "[")?;
            for (position, item) in self.0.clone().enumerate() {
                if position == 0 {
                    write!(f, "{item}")?;
                } else {
                    write!(f, ", {item}")?;
                }
            }
            write!(f, "]")
        }
    }

    BracketedList(exprs.into_iter())
}
525
526/// Prints a [`PhysicalExpr`] in a SQL-like format
527///
528/// # Example
529/// ```
530/// # // The boilerplate needed to create a `PhysicalExpr` for the example
531/// # use std::any::Any;
532/// use std::collections::HashMap;
533/// # use std::fmt::Formatter;
534/// # use std::sync::Arc;
535/// # use arrow::array::RecordBatch;
536/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema};
537/// # use datafusion_common::Result;
538/// # use datafusion_expr_common::columnar_value::ColumnarValue;
539/// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr};
540/// # #[derive(Debug, PartialEq, Eq, Hash)]
541/// # struct MyExpr {}
542/// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() }
543/// # fn data_type(&self, input_schema: &Schema) -> Result<DataType> { unimplemented!() }
544/// # fn nullable(&self, input_schema: &Schema) -> Result<bool> { unimplemented!() }
545/// # fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { unimplemented!() }
546/// # fn return_field(&self, input_schema: &Schema) -> Result<FieldRef> { unimplemented!() }
547/// # fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>>{ unimplemented!() }
548/// # fn with_new_children(self: Arc<Self>, children: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn PhysicalExpr>> { unimplemented!() }
549/// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") }
550/// # }
551/// # impl std::fmt::Display for MyExpr {fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { unimplemented!() } }
552/// # fn make_physical_expr() -> Arc<dyn PhysicalExpr> { Arc::new(MyExpr{}) }
553/// let expr: Arc<dyn PhysicalExpr> = make_physical_expr();
554/// // wrap the expression in `sql_fmt` which can be used with
555/// // `format!`, `to_string()`, etc
556/// let expr_as_sql = fmt_sql(expr.as_ref());
557/// assert_eq!(
558/// "The SQL: CASE a > b THEN 1 ELSE 0 END",
559/// format!("The SQL: {expr_as_sql}")
560/// );
561/// ```
562pub fn fmt_sql(expr: &dyn PhysicalExpr) -> impl Display + '_ {
563 struct Wrapper<'a> {
564 expr: &'a dyn PhysicalExpr,
565 }
566
567 impl Display for Wrapper<'_> {
568 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
569 self.expr.fmt_sql(f)?;
570 Ok(())
571 }
572 }
573
574 Wrapper { expr }
575}
576
577/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
578///
579/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
580/// This is used to capture the current state of `PhysicalExpr`s that may contain
581/// dynamic references to other operators in order to serialize it over the wire
582/// or treat it via downcast matching.
583///
584/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
585///
586/// # Returns
587///
588/// Returns a snapshot of the `PhysicalExpr` if it is dynamic, otherwise
589/// returns itself.
590pub fn snapshot_physical_expr(
591 expr: Arc<dyn PhysicalExpr>,
592) -> Result<Arc<dyn PhysicalExpr>> {
593 snapshot_physical_expr_opt(expr).data()
594}
595
596/// Take a snapshot of the given `PhysicalExpr` if it is dynamic.
597///
598/// Take a snapshot of this `PhysicalExpr` if it is dynamic.
599/// This is used to capture the current state of `PhysicalExpr`s that may contain
600/// dynamic references to other operators in order to serialize it over the wire
601/// or treat it via downcast matching.
602///
603/// See the documentation of [`PhysicalExpr::snapshot`] for more details.
604///
605/// # Returns
606///
607/// Returns a `[`Transformed`] indicating whether a snapshot was taken,
608/// along with the resulting `PhysicalExpr`.
609pub fn snapshot_physical_expr_opt(
610 expr: Arc<dyn PhysicalExpr>,
611) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
612 expr.transform_up(|e| {
613 if let Some(snapshot) = e.snapshot()? {
614 Ok(Transformed::yes(snapshot))
615 } else {
616 Ok(Transformed::no(Arc::clone(&e)))
617 }
618 })
619}
620
621/// Check the generation of this `PhysicalExpr`.
622/// Dynamic `PhysicalExpr`s may have a generation that is incremented
623/// every time the state of the `PhysicalExpr` changes.
624/// If the generation changes that means this `PhysicalExpr` or one of its children
625/// has changed since the last time it was evaluated.
626///
627/// This algorithm will not produce collisions as long as the structure of the
628/// `PhysicalExpr` does not change and no `PhysicalExpr` decrements its own generation.
629pub fn snapshot_generation(expr: &Arc<dyn PhysicalExpr>) -> u64 {
630 let mut generation = 0u64;
631 expr.apply(|e| {
632 // Add the current generation of the `PhysicalExpr` to our global generation.
633 generation = generation.wrapping_add(e.snapshot_generation());
634 Ok(TreeNodeRecursion::Continue)
635 })
636 .expect("this traversal is infallible");
637
638 generation
639}
640
/// Check if the given `PhysicalExpr` is dynamic.
///
/// Internally this calls [`snapshot_generation`] and checks whether the
/// combined generation of the whole expression tree is non-zero; any dynamic
/// `PhysicalExpr` should have a non-zero generation.
///
/// Returns `true` if the combined generation is non-zero, `false` otherwise.
pub fn is_dynamic_physical_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
    // If the generation is non-zero, then this `PhysicalExpr` is dynamic.
    snapshot_generation(expr) != 0
}
648
649/// Returns true if the expression is volatile, i.e. whether it can return different
650/// results when evaluated multiple times with the same input.
651///
652/// For example the function call `RANDOM()` is volatile as each call will
653/// return a different value.
654///
655/// This method recursively checks if any sub-expression is volatile, for example
656/// `1 + RANDOM()` will return `true`.
657pub fn is_volatile(expr: &Arc<dyn PhysicalExpr>) -> bool {
658 if expr.is_volatile_node() {
659 return true;
660 }
661 let mut is_volatile = false;
662 expr.apply(|e| {
663 if e.is_volatile_node() {
664 is_volatile = true;
665 Ok(TreeNodeRecursion::Stop)
666 } else {
667 Ok(TreeNodeRecursion::Continue)
668 }
669 })
670 .expect("infallible closure should not fail");
671 is_volatile
672}
673
#[cfg(test)]
mod test {
    use crate::physical_expr::PhysicalExpr;
    use arrow::array::{Array, BooleanArray, Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Schema};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::fmt::{Display, Formatter};
    use std::sync::Arc;

    /// Minimal expression used to exercise the default
    /// `PhysicalExpr::evaluate_selection`: it ignores the batch contents and
    /// yields an `Int64` array of ones with one entry per batch row.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestExpr {}

    impl PhysicalExpr for TestExpr {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn data_type(&self, _schema: &Schema) -> datafusion_common::Result<DataType> {
            Ok(DataType::Int64)
        }

        fn nullable(&self, _schema: &Schema) -> datafusion_common::Result<bool> {
            Ok(false)
        }

        fn evaluate(
            &self,
            batch: &RecordBatch,
        ) -> datafusion_common::Result<ColumnarValue> {
            let ones = Int64Array::from(vec![1; batch.num_rows()]);
            Ok(ColumnarValue::Array(Arc::new(ones)))
        }

        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
            Vec::new()
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn PhysicalExpr>>,
        ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
            Ok(Arc::new(Self {}))
        }

        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            f.write_str("TestExpr")
        }
    }

    impl Display for TestExpr {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            self.fmt_sql(f)
        }
    }

    /// A batch with an empty schema and zero rows.
    fn empty_batch() -> RecordBatch {
        RecordBatch::new_empty(Arc::new(Schema::empty()))
    }

    /// A batch with an empty schema that nevertheless reports `num_rows` rows.
    fn batch_with_rows(num_rows: usize) -> RecordBatch {
        // SAFETY: NOTE(review) the batch has no columns, so presumably no
        // column/schema/length invariant can be violated by skipping
        // validation; confirm against the `new_unchecked` safety contract.
        unsafe {
            RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], num_rows)
        }
    }

    /// A selection mask of `len` entries that are all `value`.
    fn mask(value: bool, len: usize) -> BooleanArray {
        BooleanArray::from(vec![value; len])
    }

    /// Asserts that both values materialize to equal arrays.
    fn assert_arrays_eq(
        expected: &ColumnarValue,
        actual: &ColumnarValue,
        message: &str,
    ) {
        let expected_array = expected.to_array(1).unwrap();
        let actual_array = actual.to_array(expected_array.len()).unwrap();
        let actual_ref = actual_array.as_ref();
        let expected_ref = expected_array.as_ref();
        assert!(
            actual_ref == expected_ref,
            "{}: expected: {:?}, actual: {:?}",
            message,
            expected,
            actual_ref
        );
    }

    /// Runs `evaluate_selection` and checks the result against `expected`;
    /// when every row is selected it must also agree with plain `evaluate`.
    fn test_evaluate_selection(
        batch: &RecordBatch,
        selection: &BooleanArray,
        expected: &ColumnarValue,
    ) {
        let expr = TestExpr {};

        // The selected evaluation must have the expected length and contents.
        let selection_result = expr.evaluate_selection(batch, selection).unwrap();
        assert_eq!(
            expected.to_array(1).unwrap().len(),
            selection_result.to_array(1).unwrap().len(),
            "evaluate_selection should output row count should match input record batch"
        );
        assert_arrays_eq(
            expected,
            &selection_result,
            "evaluate_selection returned unexpected value",
        );

        // With every row selected, the result has to be the same as evaluating
        // the full record batch without any selection.
        let selects_every_row = (0..batch.num_rows())
            .all(|row_idx| row_idx < selection.len() && selection.value(row_idx));
        if selects_every_row {
            let unfiltered_result = expr.evaluate(batch).unwrap();

            assert_arrays_eq(
                &unfiltered_result,
                &selection_result,
                "evaluate_selection does not match unfiltered evaluate result",
            );
        }
    }

    /// Runs `evaluate_selection` and checks that it reports an error.
    fn test_evaluate_selection_error(batch: &RecordBatch, selection: &BooleanArray) {
        let expr = TestExpr {};

        let selection_result = expr.evaluate_selection(batch, selection);
        assert!(selection_result.is_err(), "evaluate_selection should fail");
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch() {
        let expected = ColumnarValue::Array(Arc::new(Int64Array::new_null(0)));
        test_evaluate_selection(&empty_batch(), &mask(false, 0), &expected);
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_false_selection() {
        test_evaluate_selection_error(&empty_batch(), &mask(false, 10));
    }

    #[test]
    pub fn test_evaluate_selection_with_empty_record_batch_with_larger_true_selection() {
        test_evaluate_selection_error(&empty_batch(), &mask(true, 10));
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch() {
        let expected = ColumnarValue::Array(Arc::new(Int64Array::from(vec![1; 10])));
        test_evaluate_selection(&batch_with_rows(10), &mask(true, 10), &expected);
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_false_selection()
    {
        test_evaluate_selection_error(&batch_with_rows(10), &mask(false, 20));
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_larger_true_selection()
    {
        test_evaluate_selection_error(&batch_with_rows(10), &mask(true, 20));
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_false_selection()
    {
        test_evaluate_selection_error(&batch_with_rows(10), &mask(false, 5));
    }

    #[test]
    pub fn test_evaluate_selection_with_non_empty_record_batch_with_smaller_true_selection()
    {
        test_evaluate_selection_error(&batch_with_rows(10), &mask(true, 5));
    }
}